更新張量的形狀並在運行時檢查該形狀是否保持不變。
用法
tf.ensure_shape(
x, shape, name=None
)
參數
-
x
一個Tensor
。 -
shape
一個TensorShape
表示這個張量的形狀,一個TensorShapeProto
,一個列表,一個元組,或無。 -
name
此操作的名稱(可選)。默認為"EnsureShape"。
返回
-
一個
Tensor
。具有與x
相同的類型和內容。
拋出
-
tf.errors.InvalidArgumentError
如果
shape
與x
的形狀不兼容。
執行時,此操作斷言輸入張量 x
的形狀與 shape
參數兼容。有關詳細信息,請參閱tf.TensorShape.is_compatible_with
。
x = tf.constant([[1, 2, 3],
[4, 5, 6]])
x = tf.ensure_shape(x, [2, 3])
對未知尺寸使用None
:
x = tf.ensure_shape(x, [None, 3])
x = tf.ensure_shape(x, [2, None])
如果張量的形狀與 shape
參數不兼容,則會引發錯誤:
x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
tf.errors.InvalidArgumentError:Shape of tensor dummy_input [3] is not
compatible with expected shape [5]. [Op:EnsureShape]
在圖形構建期間(通常跟蹤 tf.function
),tf.ensure_shape
通過合並兩個形狀來更新結果張量的 static-shape。有關詳細信息,請參閱 tf.TensorShape.merge_with。
當您知道 TensorFlow 無法靜態確定的形狀時,這非常有用。
以下簡單的 tf.function
在應用 ensure_shape
之前和之後打印輸入張量的 static-shape。
@tf.function
def f(tensor):
print("Static-shape before:", tensor.shape)
tensor = tf.ensure_shape(tensor, [None, 3])
print("Static-shape after:", tensor.shape)
return tensor
這可以讓您在跟蹤函數時看到tf.ensure_shape
的效果:
>>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before:(None, None)
Static-shape after:(None, 3)
cf(tf.zeros([3, 3])) # Passes
cf(tf.constant([1, 2, 3])) # fails
Traceback (most recent call last):
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
上麵的示例引發了 tf.errors.InvalidArgumentError
,因為 x
的形狀 (3,)
與 shape
參數 (None, 3)
不兼容
在 tf.function
或 v1.Graph
上下文中,它會檢查構建時間和運行時形狀。這比僅檢查構建時間形狀的 tf.Tensor.set_shape
更嚴格。
注意:這與tf.Tensor.set_shape
的不同之處在於它設置了結果張量的靜態形狀並在運行時強製執行,如果張量的運行時形狀與指定的形狀不兼容,則會引發錯誤。 tf.Tensor.set_shape
設置張量的靜態形狀而不在運行時強製執行,這可能導致張量的statically-known 形狀與張量的運行時值不一致。
例如,加載已知大小的圖像:
@tf.function
def decode_image(png):
image = tf.image.decode_png(png, channels=3)
# the `print` executes during tracing.
print("Initial shape:", image.shape)
image = tf.ensure_shape(image,[28, 28, 3])
print("Final shape:", image.shape)
return image
跟蹤函數時,沒有執行任何操作,形狀可能是未知的。有關詳細信息,請參閱具體函數指南。
concrete_decode = decode_image.get_concrete_function(
tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape: (28, 28, 3)
image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
image = tf.cast(image,tf.uint8)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
print(image2.shape)
(28, 28, 3)
image = tf.concat([image,image], axis=0)
print(image.shape)
(56, 28, 3)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
compatible with expected shape [28,28,3].
警告:如果您不使用tf.ensure_shape
的結果,則檢查可能不會運行。
@tf.function
def bad_decode_image(png):
image = tf.image.decode_png(png, channels=3)
# the `print` executes during tracing.
print("Initial shape:", image.shape)
# BAD:forgot to use the returned tensor.
tf.ensure_shape(image,[28, 28, 3])
print("Final shape:", image.shape)
return image
image = bad_decode_image(png)
Initial shape: (None, None, 3)
Final shape: (None, None, 3)
print(image.shape)
(56, 28, 3)
相關用法
- Python tf.experimental.dlpack.from_dlpack用法及代碼示例
- Python tf.errors.InvalidArgumentError用法及代碼示例
- Python tf.experimental.numpy.iinfo用法及代碼示例
- Python tf.estimator.TrainSpec用法及代碼示例
- Python tf.experimental.Optional.has_value用法及代碼示例
- Python tf.estimator.LogisticRegressionHead用法及代碼示例
- Python tf.experimental.dispatch_for_unary_elementwise_apis用法及代碼示例
- Python tf.experimental.dispatch_for_api用法及代碼示例
- Python tf.estimator.MultiHead用法及代碼示例
- Python tf.experimental.unregister_dispatch_for用法及代碼示例
- Python tf.edit_distance用法及代碼示例
- Python tf.estimator.PoissonRegressionHead用法及代碼示例
- Python tf.estimator.WarmStartSettings用法及代碼示例
- Python tf.experimental.tensorrt.Converter用法及代碼示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代碼示例
- Python tf.estimator.RunConfig用法及代碼示例
- Python tf.experimental.ExtensionType用法及代碼示例
- Python tf.estimator.MultiLabelHead用法及代碼示例
- Python tf.experimental.Optional.get_value用法及代碼示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.ensure_shape。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。