用法
set_shape(
shape
)
参数
-
shape
一个TensorShape
表示这个张量的形状,一个TensorShapeProto
,一个列表,一个元组,或无。
抛出
-
ValueError
如果shape
与此张量的当前形状不兼容。
更新此张量的形状。
注意:建议使用 tf.ensure_shape
而不是 Tensor.set_shape
,因为 tf.ensure_shape
可以更好地检查编程错误并且可以为编译器优化提供保证。
通过即刻的执行,这作为形状断言运行。这里的形状匹配:
t = tf.constant([[1,2,3]])
t.set_shape([1, 3])
在新形状中传递 None
允许该轴的任何值:
t.set_shape([1,None])
如果传递了不兼容的形状,则会引发错误。
t.set_shape([1,5])
Traceback (most recent call last):
ValueError:Tensor's shape (1, 3) is not compatible with supplied
shape [1, 5]
在 tf.function
中执行或使用 tf.keras.Input
构建模型时,Tensor.set_shape
会将给定的 shape
与该张量的当前形状合并,并将张量的形状设置为合并值(参见 tf.TensorShape.merge_with
for细节):
t = tf.keras.Input(shape=[None, None, 3])
print(t.shape)
(None, None, None, 3)
设置为 None
的尺寸不会更新:
t.set_shape([None, 224, 224, None])
print(t.shape)
(None, 224, 224, 3)
其主要用例是提供无法仅从图形中推断出的附加形状信息。
例如,如果您知道数据集中的所有图像都具有 [28,28,3] 形状,则可以使用 tf.set_shape
对其进行设置:
@tf.function
def load_image(filename):
raw = tf.io.read_file(filename)
image = tf.image.decode_png(raw, channels=3)
# the `print` executes during tracing.
print("Initial shape:", image.shape)
image.set_shape([28, 28, 3])
print("Final shape:", image.shape)
return image
跟踪函数,详见具体函数指南。
cf = load_image.get_concrete_function(
tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape:(28, 28, 3)
同样,tf.io.parse_tensor
函数可以返回任何形状的张量,即使 tf.rank
是未知的。如果您知道所有序列化张量都是 2d,请使用 set_shape
进行设置:
@tf.function
def my_parse(string_tensor):
result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
# the `print` executes during tracing.
print("Initial shape:", result.shape)
result.set_shape([None, None])
print("Final shape:", result.shape)
return result
跟踪函数
concrete_parse = my_parse.get_concrete_function(
tf.TensorSpec([], dtype=tf.string))
Initial shape: <unknown>
Final shape: (None, None)
确保它有效:
t = tf.ones([5,3], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
print(serialized.dtype)
<dtype:'string'>
print(serialized.shape)
()
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 3)
警告:set_shape
确保应用的形状与现有形状兼容,但不会在运行时检查。设置不正确的形状可能会导致 statically-known 图形与张量的运行时值不一致。对于形状的运行时验证,请改用tf.ensure_shape
。它还修改了张量的shape
。
# Serialize a rank-3 tensor
t = tf.ones([5,5,5], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
# The function still runs, even though it `set_shape([None,None])`
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 5, 5)
相关用法
- Python tf.Tensor.__rsub__用法及代码示例
- Python tf.Tensor.__lt__用法及代码示例
- Python tf.Tensor.__abs__用法及代码示例
- Python tf.Tensor.ref用法及代码示例
- Python tf.Tensor.__getitem__用法及代码示例
- Python tf.Tensor.__ge__用法及代码示例
- Python tf.Tensor.__rmatmul__用法及代码示例
- Python tf.Tensor.__bool__用法及代码示例
- Python tf.Tensor.get_shape用法及代码示例
- Python tf.Tensor.__xor__用法及代码示例
- Python tf.Tensor.__sub__用法及代码示例
- Python tf.Tensor.__rpow__用法及代码示例
- Python tf.Tensor.__gt__用法及代码示例
- Python tf.Tensor.__le__用法及代码示例
- Python tf.Tensor.__pow__用法及代码示例
- Python tf.Tensor.__matmul__用法及代码示例
- Python tf.TensorSpec.from_spec用法及代码示例
- Python tf.Tensor用法及代码示例
- Python tf.TensorSpec.from_tensor用法及代码示例
- Python tf.TensorShape.merge_with用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Tensor.set_shape。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。