当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.experimental.BatchableExtensionType用法及代码示例


可以成批处理和不成批处理的 ExtensionType。

继承自:ExtensionType

用法

tf.experimental.BatchableExtensionType(
    *args, **kwargs
)

BatchableExtensionType 可以与需要批处理或取消批处理的 API 一起使用,包括 Kerastf.data.Datasettf.map_fn 。例如:

class Vehicle(BatchableExtensionType):
  top_speed:tf.Tensor
  mpg:tf.Tensor
batch = Vehicle([120, 150, 80], [30, 40, 12])
tf.map_fn(lambda vehicle:vehicle.top_speed * vehicle.mpg, batch,
          fn_output_signature=tf.int32).numpy()
array([3600, 6000,  960], dtype=int32)

这些 API 使用 ExtensionTypeBatchEncoder 来编码 ExtensionType 值。默认编码器假定值可以通过简单地堆叠、取消堆叠或连接每个嵌套的 Tensor , ExtensionType , CompositeTensorTensorShape 字段来堆叠、取消堆叠或连接。不是这种情况的扩展类型将需要使用自定义 ExtensionTypeBatchEncoder 覆盖 __batch_encoder__ 。有关详细信息,请参阅tf.experimental.ExtensionTypeBatchEncoder

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.experimental.BatchableExtensionType。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。