當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.Tensor.set_shape用法及代碼示例


用法

set_shape(
    shape
)

參數

  • shape 一個 TensorShape 表示這個張量的形狀,一個 TensorShapeProto ,一個列表,一個元組,或無。

拋出

  • ValueError 如果shape 與此張量的當前形狀不兼容。

更新此張量的形狀。

注意:建議使用 tf.ensure_shape 而不是 Tensor.set_shape ,因為 tf.ensure_shape 可以更好地檢查編程錯誤並且可以為編譯器優化提供保證。

通過即刻的執行,這作為形狀斷言運行。這裏的形狀匹配:

t = tf.constant([[1,2,3]])
t.set_shape([1, 3])

在新形狀中傳遞 None 允許該軸的任何值:

t.set_shape([1,None])

如果傳遞了不兼容的形狀,則會引發錯誤。

t.set_shape([1,5])
Traceback (most recent call last):

ValueError:Tensor's shape (1, 3) is not compatible with supplied
shape [1, 5]

tf.function 中執行或使用 tf.keras.Input 構建模型時,Tensor.set_shape 會將給定的 shape 與該張量的當前形狀合並,並將張量的形狀設置為合並值(參見 tf.TensorShape.merge_with for細節):

t = tf.keras.Input(shape=[None, None, 3])
print(t.shape)
(None, None, None, 3)

設置為 None 的尺寸不會更新:

t.set_shape([None, 224, 224, None])
print(t.shape)
(None, 224, 224, 3)

其主要用例是提供無法僅從圖形中推斷出的附加形狀信息。

例如,如果您知道數據集中的所有圖像都具有 [28,28,3] 形狀,則可以使用 tf.set_shape 對其進行設置:

@tf.function
def load_image(filename):
  raw = tf.io.read_file(filename)
  image = tf.image.decode_png(raw, channels=3)
  # the `print` executes during tracing.
  print("Initial shape:", image.shape)
  image.set_shape([28, 28, 3])
  print("Final shape:", image.shape)
  return image

跟蹤函數,詳見具體函數指南。

cf = load_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape:(28, 28, 3)

同樣,tf.io.parse_tensor 函數可以返回任何形狀的張量,即使 tf.rank 是未知的。如果您知道所有序列化張量都是 2d,請使用 set_shape 進行設置:

@tf.function
def my_parse(string_tensor):
  result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
  # the `print` executes during tracing.
  print("Initial shape:", result.shape)
  result.set_shape([None, None])
  print("Final shape:", result.shape)
  return result

跟蹤函數

concrete_parse = my_parse.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: <unknown>
Final shape: (None, None)

確保它有效:

t = tf.ones([5,3], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
print(serialized.dtype)
<dtype:'string'>
print(serialized.shape)
()
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 3)

警告:set_shape 確保應用的形狀與現有形狀兼容,但不會在運行時檢查。設置不正確的形狀可能會導致 statically-known 圖形與張量的運行時值不一致。對於形狀的運行時驗證,請改用tf.ensure_shape。它還修改了張量的shape

# Serialize a rank-3 tensor
t = tf.ones([5,5,5], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
# The function still runs, even though it `set_shape([None,None])`
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 5, 5)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.Tensor.set_shape。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。