当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.recompute_grad用法及代码示例


将函数定义为磁带 auto-diff 的重新计算检查点。

用法

tf.recompute_grad(
    f
)

参数

  • f 函数 f(*x) 返回 TensorTensor 输出序列。

返回

  • 包装 f 的函数 g 定义了自定义渐变,它在渐变调用的反向传递中重新计算 f

磁带检查点是一种减少 auto-diff 磁带内存消耗的技术:

  • 在没有磁带检查点操作的情况下,中间值被记录到磁带中以用于反向传递。

  • 使用磁带检查点,只记录函数调用及其输入。在back-propagation 期间,recompute_grad 自定义渐变 (tf.custom_gradient) 在本地化 Tape 对象下重新计算函数。这种在反向传播期间对函数的重新计算会执行冗余计算,但会减少磁带的整体内存使用量。

y = tf.Variable(1.0)
def my_function(x):
  tf.print('running')
  z = x*y
  return z
my_function_recompute = tf.recompute_grad(my_function)
with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function_recompute(r)
running
running
running
running
grad = tape.gradient(r, [y])
running
running
running
running

如果没有 recompute_grad ,磁带将包含所有间歇步骤,并且不会执行重新计算。

with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function(r)
running
running
running
running
grad = tape.gradient(r, [y])

如果 ftf.keras ModelLayer 对象,则在返回的函数 g 上不提供 f.variables 等方法和属性。或者保留 f 的引用,或者使用 g.__wrapped__ 来访问这些变量和方法。

def print_running_and_return(x):
  tf.print("running")
  return x
model = tf.keras.Sequential([
  tf.keras.layers.Lambda(print_running_and_return),
  tf.keras.layers.Dense(2)
])
model_recompute = tf.recompute_grad(model)
with tf.GradientTape(persistent=True) as tape:
  r = tf.constant([[1,2]])
  for i in range(4):
    r = model_recompute(r)
running
running
running
running
grad = tape.gradient(r, model.variables)
running
running
running
running

或者,使用__wrapped__ 属性访问原始模型对象。

grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.recompute_grad。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。