當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.Graph.gradient_override_map用法及代碼示例


用法

@tf_contextlib.contextmanager
gradient_override_map(
    op_type_map
)

參數

  • op_type_map 將操作類型字符串映射到替代操作類型字符串的字典。

返回

  • 一個上下文管理器,它設置用於在該上下文中創建的一個或多個操作的替代操作類型。

拋出

  • TypeError 如果 op_type_map 不是將字符串映射到字符串的字典。

實驗:用於覆蓋梯度函數的上下文管理器。

此上下文管理器可用於覆蓋將用於上下文範圍內操作的梯度函數。

例如:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  # ...

with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  s_1 = tf.square(c)  # Uses the default gradient for tf.square.
  with g.gradient_override_map({"Square":"CustomSquare"}):
    s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.Graph.gradient_override_map。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。