当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.experimental.dispatch_for_unary_elementwise_apis用法及代码示例


装饰器覆盖一元元素 API 的默认实现。

用法

tf.experimental.dispatch_for_unary_elementwise_apis(
    x_type
)

参数

  • x_type 指示何时应调用 api 处理程序的类型注释。有关受支持的注释类型的列表,请参阅dispatch_for_api

返回

  • 一个装饰师。

只要第一个参数(通常命名为 x )的值与类型注释 x_type 匹配,装饰函数(称为“元素 api 处理程序”)就会覆盖任何一元元素 API 的默认实现。 elementwise api 处理程序使用两个参数调用:

elementwise_api_handler(api_func, x)

其中 api_func 是一个接受单个参数并执行元素操作的函数(例如 tf.abs ),而 x 是元素 api 的第一个参数。

以下示例显示了如何使用此装饰器更新所有一元元素操作以处理 MaskedTensor 类型:

class MaskedTensor(tf.experimental.ExtensionType):
  values:tf.Tensor
  mask:tf.Tensor
@dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
  return MaskedTensor(api_func(x.values), x.mask)
mt = MaskedTensor([1, -2, -3], [True, False, True])
abs_mt = tf.abs(mt)
print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
values=[1 2 3], mask=[ True False True]

对于接受超出 x 的额外参数的一元元素操作,这些参数不会传递给元素 api 处理程序,而是在调用 api_func 时自动添加。例如,在以下示例中,dtype 参数未传递给 unary_elementwise_api_handler ,而是由 api_func 添加。

ones_mt = tf.ones_like(mt, dtype=tf.float32)
print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
values=[1.0 1.0 1.0], mask=[ True False True]

注册的 API

一元元素 API 是:

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.experimental.dispatch_for_unary_elementwise_apis。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。