用法
set_shape(
shape
)
參數
-
shape
一個TensorShape
表示這個張量的形狀,一個TensorShapeProto
,一個列表,一個元組,或無。
拋出
-
ValueError
如果shape
與此張量的當前形狀不兼容。
更新此張量的形狀。
注意:建議使用 tf.ensure_shape
而不是 Tensor.set_shape
,因為 tf.ensure_shape
可以更好地檢查編程錯誤並且可以為編譯器優化提供保證。
通過即刻的執行,這作為形狀斷言運行。這裏的形狀匹配:
t = tf.constant([[1,2,3]])
t.set_shape([1, 3])
在新形狀中傳遞 None
允許該軸的任何值:
t.set_shape([1,None])
如果傳遞了不兼容的形狀,則會引發錯誤。
t.set_shape([1,5])
Traceback (most recent call last):
ValueError:Tensor's shape (1, 3) is not compatible with supplied
shape [1, 5]
在 tf.function
中執行或使用 tf.keras.Input
構建模型時,Tensor.set_shape
會將給定的 shape
與該張量的當前形狀合並,並將張量的形狀設置為合並值(參見 tf.TensorShape.merge_with
for細節):
t = tf.keras.Input(shape=[None, None, 3])
print(t.shape)
(None, None, None, 3)
設置為 None
的尺寸不會更新:
t.set_shape([None, 224, 224, None])
print(t.shape)
(None, 224, 224, 3)
其主要用例是提供無法僅從圖形中推斷出的附加形狀信息。
例如,如果您知道數據集中的所有圖像都具有 [28,28,3] 形狀,則可以使用 tf.set_shape
對其進行設置:
@tf.function
def load_image(filename):
raw = tf.io.read_file(filename)
image = tf.image.decode_png(raw, channels=3)
# the `print` executes during tracing.
print("Initial shape:", image.shape)
image.set_shape([28, 28, 3])
print("Final shape:", image.shape)
return image
跟蹤函數,詳見具體函數指南。
cf = load_image.get_concrete_function(
tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape:(28, 28, 3)
同樣,tf.io.parse_tensor
函數可以返回任何形狀的張量,即使 tf.rank
是未知的。如果您知道所有序列化張量都是 2d,請使用 set_shape
進行設置:
@tf.function
def my_parse(string_tensor):
result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
# the `print` executes during tracing.
print("Initial shape:", result.shape)
result.set_shape([None, None])
print("Final shape:", result.shape)
return result
跟蹤函數
concrete_parse = my_parse.get_concrete_function(
tf.TensorSpec([], dtype=tf.string))
Initial shape: <unknown>
Final shape: (None, None)
確保它有效:
t = tf.ones([5,3], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
print(serialized.dtype)
<dtype:'string'>
print(serialized.shape)
()
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 3)
警告:set_shape
確保應用的形狀與現有形狀兼容,但不會在運行時檢查。設置不正確的形狀可能會導致 statically-known 圖形與張量的運行時值不一致。對於形狀的運行時驗證,請改用tf.ensure_shape
。它還修改了張量的shape
。
# Serialize a rank-3 tensor
t = tf.ones([5,5,5], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
# The function still runs, even though it `set_shape([None,None])`
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 5, 5)
相關用法
- Python tf.Tensor.__rsub__用法及代碼示例
- Python tf.Tensor.__lt__用法及代碼示例
- Python tf.Tensor.__abs__用法及代碼示例
- Python tf.Tensor.ref用法及代碼示例
- Python tf.Tensor.__getitem__用法及代碼示例
- Python tf.Tensor.__ge__用法及代碼示例
- Python tf.Tensor.__rmatmul__用法及代碼示例
- Python tf.Tensor.__bool__用法及代碼示例
- Python tf.Tensor.get_shape用法及代碼示例
- Python tf.Tensor.__xor__用法及代碼示例
- Python tf.Tensor.__sub__用法及代碼示例
- Python tf.Tensor.__rpow__用法及代碼示例
- Python tf.Tensor.__gt__用法及代碼示例
- Python tf.Tensor.__le__用法及代碼示例
- Python tf.Tensor.__pow__用法及代碼示例
- Python tf.Tensor.__matmul__用法及代碼示例
- Python tf.TensorSpec.from_spec用法及代碼示例
- Python tf.Tensor用法及代碼示例
- Python tf.TensorSpec.from_tensor用法及代碼示例
- Python tf.TensorShape.merge_with用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.Tensor.set_shape。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。