當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.Tensor.get_shape用法及代碼示例


用法

get_shape()

返回

返回表示此張量形狀的 tf.TensorShape

在即刻執行中,形狀總是fully-known。

a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.shape)
(2, 3)

tf.Tensor.get_shape() 等同於 tf.Tensor.shape

tf.function 中執行或使用 tf.keras.Input 構建模型時,Tensor.shape 可能會返回部分形狀(包括未知尺寸的 None)。有關詳細信息,請參閱tf.TensorShape

inputs = tf.keras.Input(shape = [10])
# Unknown batch size
print(inputs.shape)
(None, 10)

使用為每個 tf.Operation 注冊的形狀推斷函數計算形狀。

返回的tf.TensorShape 在構建時確定,而不執行底層內核。它不是 tf.Tensor。如果您需要形狀張量,請將 tf.TensorShape 轉換為 tf.constant,或使用 tf.shape(tensor) 函數,該函數在執行時返回張量的形狀。

這對於調試和提供早期錯誤很有用。例如,當跟蹤 tf.function 時,沒有執行任何操作,形狀可能是未知的(有關詳細信息,請參閱具體函數指南)。

@tf.function
def my_matmul(a, b):
  result = a@b
  # the `print` executes during tracing.
  print("Result shape:", result.shape)
  return result

形狀推斷函數盡可能地傳播形狀:

f = my_matmul.get_concrete_function(
  tf.TensorSpec([None,3]),
  tf.TensorSpec([3,5]))
Result shape:(None, 5)

如果可以檢測到形狀不匹配,則跟蹤可能會失敗:

cf = my_matmul.get_concrete_function(
  tf.TensorSpec([None,3]),
  tf.TensorSpec([4,5]))
Traceback (most recent call last):

ValueError:Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
'MatMul') with input shapes:[?,3], [4,5].

在某些情況下,推斷的形狀可能具有未知尺寸。如果調用者有關於這些維度值的附加信息,tf.ensure_shapeTensor.set_shape() 可用於增強推斷的形狀。

@tf.function
def my_fun(a):
  a = tf.ensure_shape(a, [5, 5])
  # the `print` executes during tracing.
  print("Result shape:", a.shape)
  return a
cf = my_fun.get_concrete_function(
  tf.TensorSpec([None, None]))
Result shape:(5, 5)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.Tensor.get_shape。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。