当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Tensor.get_shape用法及代码示例


用法

get_shape()

返回

返回表示此张量形状的 tf.TensorShape

在即刻执行中,形状总是fully-known。

a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.shape)
(2, 3)

tf.Tensor.get_shape() 等同于 tf.Tensor.shape

tf.function 中执行或使用 tf.keras.Input 构建模型时,Tensor.shape 可能会返回部分形状(包括未知尺寸的 None)。有关详细信息,请参阅tf.TensorShape

inputs = tf.keras.Input(shape = [10])
# Unknown batch size
print(inputs.shape)
(None, 10)

使用为每个 tf.Operation 注册的形状推断函数计算形状。

返回的tf.TensorShape 在构建时确定,而不执行底层内核。它不是 tf.Tensor。如果您需要形状张量,请将 tf.TensorShape 转换为 tf.constant,或使用 tf.shape(tensor) 函数,该函数在执行时返回张量的形状。

这对于调试和提供早期错误很有用。例如,当跟踪 tf.function 时,没有执行任何操作,形状可能是未知的(有关详细信息,请参阅具体函数指南)。

@tf.function
def my_matmul(a, b):
  result = a@b
  # the `print` executes during tracing.
  print("Result shape:", result.shape)
  return result

形状推断函数尽可能地传播形状:

f = my_matmul.get_concrete_function(
  tf.TensorSpec([None,3]),
  tf.TensorSpec([3,5]))
Result shape:(None, 5)

如果可以检测到形状不匹配,则跟踪可能会失败:

cf = my_matmul.get_concrete_function(
  tf.TensorSpec([None,3]),
  tf.TensorSpec([4,5]))
Traceback (most recent call last):

ValueError:Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
'MatMul') with input shapes:[?,3], [4,5].

在某些情况下,推断的形状可能具有未知尺寸。如果调用者有关于这些维度值的附加信息,tf.ensure_shapeTensor.set_shape() 可用于增强推断的形状。

@tf.function
def my_fun(a):
  a = tf.ensure_shape(a, [5, 5])
  # the `print` executes during tracing.
  print("Result shape:", a.shape)
  return a
cf = my_fun.get_concrete_function(
  tf.TensorSpec([None, None]))
Result shape:(5, 5)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Tensor.get_shape。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。