當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.experimental.dispatch_for_api用法及代碼示例

覆蓋 TensorFlow API 的默認實現的裝飾器。

用法

tf.experimental.dispatch_for_api(
    api, *signatures
)

參數

  • api 要覆蓋的 TensorFlow API。
  • *signatures 字典將參數名稱或索引映射到類型注釋,指定何時調用調度目標。特別是,如果任何簽名匹配,將調用調度目標;如果所有指定參數的類型都與指定的類型注釋匹配,則簽名匹配。如果未指定簽名,則將從調度目標函數的類型注釋中讀取簽名。

返回

  • 覆蓋 api 的默認實現的裝飾器。

當使用與指定類型簽名匹配的參數調用 API 時,裝飾函數(稱為 "dispatch target")將覆蓋 API 的默認實現。使用將參數名稱映射到類型注釋的字典來指定簽名。例如,在以下示例中,如果 xy 都是 MaskedTensor ,則將為 tf.add 調用 masked_add

class MaskedTensor(tf.experimental.ExtensionType):
  values:tf.Tensor
  mask:tf.Tensor
@dispatch_for_api(tf.math.add, {'x':MaskedTensor, 'y':MaskedTensor})
def masked_add(x, y, name=None):
  return MaskedTensor(x.values + y.values, x.mask & y.mask)
mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
values=[11 12], mask=[ True False]

如果指定了多個類型簽名,則如果任何簽名匹配,則將調用調度目標。例如,以下代碼寄存器masked_add如果被調用x是一個MaskedTensor 或者 y是一個MaskedTensor.

@dispatch_for_api(tf.math.add, {'x':MaskedTensor}, {'y':MaskedTensor})
def masked_add(x, y):
  x_values = x.values if isinstance(x, MaskedTensor) else x
  x_mask = x.mask if isinstance(x, MaskedTensor) else True
  y_values = y.values if isinstance(y, MaskedTensor) else y
  y_mask = y.mask if isinstance(y, MaskedTensor) else True
  return MaskedTensor(x_values + y_values, x_mask & y_mask)

類型簽名中的類型注釋可以是類型對象(例如 MaskedTensor )、typing.List 值或 typing.Union 值。例如,如果valuesMaskedTensor 值的列表,則以下將注冊要調用的masked_concat

@dispatch_for_api(tf.concat, {'values':typing.List[MaskedTensor]})
def masked_concat(values, axis):
  return MaskedTensor(tf.concat([v.values for v in values], axis),
                      tf.concat([v.mask for v in values], axis))

每個類型簽名必須包含至少一個 tf.CompositeTensor 的子類(包括 tf.ExtensionType 的子類),並且隻有在至少一個 type-annotated 參數包含 CompositeTensor 值時才會觸發調度。此規則避免在退化情況下調用調度,例如以下示例:

  • @dispatch_for_api(tf.concat, {'values':List[MaskedTensor]}) :當用戶調用 tf.concat([]) 時,不會分派到修飾的分派目標。

  • @dispatch_for_api(tf.add, {'x':Union[MaskedTensor, Tensor], 'y': Union[MaskedTensor, Tensor]}):當用戶調用時不會分派到裝飾的分派目標tf.add(tf.constant(1), tf.constant(2)).

調度目標的簽名必須與被覆蓋的 API 的簽名相匹配。特別是,參數必須具有相同的名稱,並且必須以相同的順序出現。調度目標可以選擇性地省略 "name" 參數,在這種情況下,它將在適當的時候通過對 tf.name_scope 的調用來包裝。

注冊的 API

@dispatch_for_api 可能覆蓋的 TensorFlow API 是:

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.experimental.dispatch_for_api。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。