当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Tensor.set_shape用法及代码示例


用法

set_shape(
    shape
)

参数

  • shape 一个 TensorShape 表示这个张量的形状,一个 TensorShapeProto ,一个列表,一个元组,或无。

抛出

  • ValueError 如果shape 与此张量的当前形状不兼容。

更新此张量的形状。

注意:建议使用 tf.ensure_shape 而不是 Tensor.set_shape ,因为 tf.ensure_shape 可以更好地检查编程错误并且可以为编译器优化提供保证。

通过即刻的执行,这作为形状断言运行。这里的形状匹配:

t = tf.constant([[1,2,3]])
t.set_shape([1, 3])

在新形状中传递 None 允许该轴的任何值:

t.set_shape([1,None])

如果传递了不兼容的形状,则会引发错误。

t.set_shape([1,5])
Traceback (most recent call last):

ValueError:Tensor's shape (1, 3) is not compatible with supplied
shape [1, 5]

tf.function 中执行或使用 tf.keras.Input 构建模型时,Tensor.set_shape 会将给定的 shape 与该张量的当前形状合并,并将张量的形状设置为合并值(参见 tf.TensorShape.merge_with for细节):

t = tf.keras.Input(shape=[None, None, 3])
print(t.shape)
(None, None, None, 3)

设置为 None 的尺寸不会更新:

t.set_shape([None, 224, 224, None])
print(t.shape)
(None, 224, 224, 3)

其主要用例是提供无法仅从图形中推断出的附加形状信息。

例如,如果您知道数据集中的所有图像都具有 [28,28,3] 形状,则可以使用 tf.set_shape 对其进行设置:

@tf.function
def load_image(filename):
  raw = tf.io.read_file(filename)
  image = tf.image.decode_png(raw, channels=3)
  # the `print` executes during tracing.
  print("Initial shape:", image.shape)
  image.set_shape([28, 28, 3])
  print("Final shape:", image.shape)
  return image

跟踪函数,详见具体函数指南。

cf = load_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape:(28, 28, 3)

同样,tf.io.parse_tensor 函数可以返回任何形状的张量,即使 tf.rank 是未知的。如果您知道所有序列化张量都是 2d,请使用 set_shape 进行设置:

@tf.function
def my_parse(string_tensor):
  result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
  # the `print` executes during tracing.
  print("Initial shape:", result.shape)
  result.set_shape([None, None])
  print("Final shape:", result.shape)
  return result

跟踪函数

concrete_parse = my_parse.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: <unknown>
Final shape: (None, None)

确保它有效:

t = tf.ones([5,3], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
print(serialized.dtype)
<dtype:'string'>
print(serialized.shape)
()
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 3)

警告:set_shape 确保应用的形状与现有形状兼容,但不会在运行时检查。设置不正确的形状可能会导致 statically-known 图形与张量的运行时值不一致。对于形状的运行时验证,请改用tf.ensure_shape。它还修改了张量的shape

# Serialize a rank-3 tensor
t = tf.ones([5,5,5], dtype=tf.float32)
serialized = tf.io.serialize_tensor(t)
# The function still runs, even though it `set_shape([None,None])`
t2 = concrete_parse(serialized)
print(t2.shape)
(5, 5, 5)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Tensor.set_shape。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。