當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.experimental.BatchableExtensionType用法及代碼示例

可以成批處理和不成批處理的 ExtensionType。

繼承自:ExtensionType

用法

tf.experimental.BatchableExtensionType(
    *args, **kwargs
)

BatchableExtensionType 可以與需要批處理或取消批處理的 API 一起使用,包括 Kerastf.data.Datasettf.map_fn 。例如:

class Vehicle(BatchableExtensionType):
  top_speed:tf.Tensor
  mpg:tf.Tensor
batch = Vehicle([120, 150, 80], [30, 40, 12])
tf.map_fn(lambda vehicle:vehicle.top_speed * vehicle.mpg, batch,
          fn_output_signature=tf.int32).numpy()
array([3600, 6000,  960], dtype=int32)

這些 API 使用 ExtensionTypeBatchEncoder 來編碼 ExtensionType 值。默認編碼器假定值可以通過簡單地堆疊、取消堆疊或連接每個嵌套的 Tensor , ExtensionType , CompositeTensorTensorShape 字段來堆疊、取消堆疊或連接。不是這種情況的擴展類型將需要使用自定義 ExtensionTypeBatchEncoder 覆蓋 __batch_encoder__ 。有關詳細信息,請參閱tf.experimental.ExtensionTypeBatchEncoder

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.experimental.BatchableExtensionType。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。