当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python tf.RegisterGradient用法及代码示例

用于注册操作类型的梯度函数的装饰器。

用法

tf.RegisterGradient(
    op_type
)

参数

  • op_type 操作的字符串类型。这对应于定义操作的 proto 的 OpDef.name 字段。

抛出

  • TypeError 如果 op_type 不是字符串。

此装饰器仅在定义新的操作类型时使用。对于具有 m 输入和 n 输出的操作,梯度函数是一个采用原始 Operationn Tensor 对象的函数(表示相对于操作的每个输出的梯度),并且返回 m Tensor 对象(表示相对于操作的每个输入的部分梯度)。

例如,假设 "Sub" 类型的操作采用两个输入 xy ,并返回单个输出 x - y ,将注册以下梯度函数:

@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
  return grad, tf.negative(grad)

装饰器参数op_type 是操作的字符串类型。这对应于定义操作的原型的OpDef.name 字段。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.RegisterGradient。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。