当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Graph.gradient_override_map用法及代码示例


用法

@tf_contextlib.contextmanager
gradient_override_map(
    op_type_map
)

参数

  • op_type_map 将操作类型字符串映射到替代操作类型字符串的字典。

返回

  • 一个上下文管理器,它设置用于在该上下文中创建的一个或多个操作的替代操作类型。

抛出

  • TypeError 如果 op_type_map 不是将字符串映射到字符串的字典。

实验:用于覆盖梯度函数的上下文管理器。

此上下文管理器可用于覆盖将用于上下文范围内操作的梯度函数。

例如:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  # ...

with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  s_1 = tf.square(c)  # Uses the default gradient for tf.square.
  with g.gradient_override_map({"Square":"CustomSquare"}):
    s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Graph.gradient_override_map。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。