当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.ensure_shape用法及代码示例


更新张量的形状并在运行时检查该形状是否保持不变。

用法

tf.ensure_shape(
    x, shape, name=None
)

参数

  • x 一个Tensor
  • shape 一个 TensorShape 表示这个张量的形状,一个 TensorShapeProto ,一个列表,一个元组,或无。
  • name 此操作的名称(可选)。默认为"EnsureShape"。

返回

  • 一个Tensor。具有与 x 相同的类型和内容。

抛出

执行时,此操作断言输入张量 x 的形状与 shape 参数兼容。有关详细信息,请参阅tf.TensorShape.is_compatible_with

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
x = tf.ensure_shape(x, [2, 3])

对未知尺寸使用None

x = tf.ensure_shape(x, [None, 3])
x = tf.ensure_shape(x, [2, None])

如果张量的形状与 shape 参数不兼容,则会引发错误:

x = tf.ensure_shape(x, [5])
Traceback (most recent call last):

tf.errors.InvalidArgumentError:Shape of tensor dummy_input [3] is not
  compatible with expected shape [5]. [Op:EnsureShape]

在图形构建期间(通常跟踪 tf.function ),tf.ensure_shape 通过合并两个形状来更新结果张量的 static-shape。有关详细信息,请参阅 tf.TensorShape.merge_with。

当您知道 TensorFlow 无法静态确定的形状时,这非常有用。

以下简单的 tf.function 在应用 ensure_shape 之前和之后打印输入张量的 static-shape。

@tf.function
def f(tensor):
  print("Static-shape before:", tensor.shape)
  tensor = tf.ensure_shape(tensor, [None, 3])
  print("Static-shape after:", tensor.shape)
  return tensor

这可以让您在跟踪函数时看到tf.ensure_shape 的效果:

>>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before:(None, None)
Static-shape after:(None, 3)
cf(tf.zeros([3, 3])) # Passes
cf(tf.constant([1, 2, 3])) # fails
Traceback (most recent call last):

InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].

上面的示例引发了 tf.errors.InvalidArgumentError ,因为 x 的形状 (3,)shape 参数 (None, 3) 不兼容

tf.functionv1.Graph 上下文中,它会检查构建时间和运行时形状。这比仅检查构建时间形状的 tf.Tensor.set_shape 更严格。

注意:这与tf.Tensor.set_shape 的不同之处在于它设置了结果张量的静态形状并在运行时强制执行,如果张量的运行时形状与指定的形状不兼容,则会引发错误。 tf.Tensor.set_shape 设置张量的静态形状而不在运行时强制执行,这可能导致张量的statically-known 形状与张量的运行时值不一致。

例如,加载已知大小的图像:

@tf.function
def decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape:", image.shape)
  image = tf.ensure_shape(image,[28, 28, 3])
  print("Final shape:", image.shape)
  return image

跟踪函数时,没有执行任何操作,形状可能是未知的。有关详细信息,请参阅具体函数指南。

concrete_decode = decode_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape: (28, 28, 3)
image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
image = tf.cast(image,tf.uint8)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
print(image2.shape)
(28, 28, 3)
image = tf.concat([image,image], axis=0)
print(image.shape)
(56, 28, 3)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
Traceback (most recent call last):

tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
  compatible with expected shape [28,28,3].

警告:如果您不使用tf.ensure_shape 的结果,则检查可能不会运行。

@tf.function
def bad_decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape:", image.shape)
  # BAD:forgot to use the returned tensor.
  tf.ensure_shape(image,[28, 28, 3])
  print("Final shape:", image.shape)
  return image
image = bad_decode_image(png)
Initial shape: (None, None, 3)
Final shape: (None, None, 3)
print(image.shape)
(56, 28, 3)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.ensure_shape。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。