当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.experimental.dispatch_for_api用法及代码示例


覆盖 TensorFlow API 的默认实现的装饰器。

用法

tf.experimental.dispatch_for_api(
    api, *signatures
)

参数

  • api 要覆盖的 TensorFlow API。
  • *signatures 字典将参数名称或索引映射到类型注释,指定何时调用调度目标。特别是,如果任何签名匹配,将调用调度目标;如果所有指定参数的类型都与指定的类型注释匹配,则签名匹配。如果未指定签名,则将从调度目标函数的类型注释中读取签名。

返回

  • 覆盖 api 的默认实现的装饰器。

当使用与指定类型签名匹配的参数调用 API 时,装饰函数(称为 "dispatch target")将覆盖 API 的默认实现。使用将参数名称映射到类型注释的字典来指定签名。例如,在以下示例中,如果 xy 都是 MaskedTensor ,则将为 tf.add 调用 masked_add

class MaskedTensor(tf.experimental.ExtensionType):
  values:tf.Tensor
  mask:tf.Tensor
@dispatch_for_api(tf.math.add, {'x':MaskedTensor, 'y':MaskedTensor})
def masked_add(x, y, name=None):
  return MaskedTensor(x.values + y.values, x.mask & y.mask)
mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
values=[11 12], mask=[ True False]

如果指定了多个类型签名,则如果任何签名匹配,则将调用调度目标。例如,以下代码寄存器masked_add如果被调用x是一个MaskedTensor 或者 y是一个MaskedTensor.

@dispatch_for_api(tf.math.add, {'x':MaskedTensor}, {'y':MaskedTensor})
def masked_add(x, y):
  x_values = x.values if isinstance(x, MaskedTensor) else x
  x_mask = x.mask if isinstance(x, MaskedTensor) else True
  y_values = y.values if isinstance(y, MaskedTensor) else y
  y_mask = y.mask if isinstance(y, MaskedTensor) else True
  return MaskedTensor(x_values + y_values, x_mask & y_mask)

类型签名中的类型注释可以是类型对象(例如 MaskedTensor )、typing.List 值或 typing.Union 值。例如,如果valuesMaskedTensor 值的列表,则以下将注册要调用的masked_concat

@dispatch_for_api(tf.concat, {'values':typing.List[MaskedTensor]})
def masked_concat(values, axis):
  return MaskedTensor(tf.concat([v.values for v in values], axis),
                      tf.concat([v.mask for v in values], axis))

每个类型签名必须包含至少一个 tf.CompositeTensor 的子类(包括 tf.ExtensionType 的子类),并且只有在至少一个 type-annotated 参数包含 CompositeTensor 值时才会触发调度。此规则避免在退化情况下调用调度,例如以下示例:

  • @dispatch_for_api(tf.concat, {'values':List[MaskedTensor]}) :当用户调用 tf.concat([]) 时,不会分派到修饰的分派目标。

  • @dispatch_for_api(tf.add, {'x':Union[MaskedTensor, Tensor], 'y': Union[MaskedTensor, Tensor]}):当用户调用时不会分派到装饰的分派目标tf.add(tf.constant(1), tf.constant(2)).

调度目标的签名必须与被覆盖的 API 的签名相匹配。特别是,参数必须具有相同的名称,并且必须以相同的顺序出现。调度目标可以选择性地省略 "name" 参数,在这种情况下,它将在适当的时候通过对 tf.name_scope 的调用来包装。

注册的 API

@dispatch_for_api 可能覆盖的 TensorFlow API 是:

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.experimental.dispatch_for_api。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。