當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.RegisterGradient用法及代碼示例

用於注冊操作類型的梯度函數的裝飾器。

用法

tf.RegisterGradient(
    op_type
)

參數

  • op_type 操作的字符串類型。這對應於定義操作的 proto 的 OpDef.name 字段。

拋出

  • TypeError 如果 op_type 不是字符串。

此裝飾器僅在定義新的操作類型時使用。對於具有 m 輸入和 n 輸出的操作,梯度函數是一個采用原始 Operationn Tensor 對象的函數(表示相對於操作的每個輸出的梯度),並且返回 m Tensor 對象(表示相對於操作的每個輸入的部分梯度)。

例如,假設 "Sub" 類型的操作采用兩個輸入 xy ,並返回單個輸出 x - y ,將注冊以下梯度函數:

@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
  return grad, tf.negative(grad)

裝飾器參數op_type 是操作的字符串類型。這對應於定義操作的原型的OpDef.name 字段。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.RegisterGradient。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。