将函数定义为磁带 auto-diff 的重新计算检查点。
用法
tf.recompute_grad(
f
)
参数
-
f
函数f(*x)
返回Tensor
或Tensor
输出序列。
返回
-
包装
f
的函数g
定义了自定义渐变,它在渐变调用的反向传递中重新计算f
。
磁带检查点是一种减少 auto-diff 磁带内存消耗的技术:
在没有磁带检查点操作的情况下,中间值被记录到磁带中以用于反向传递。
使用磁带检查点,只记录函数调用及其输入。在back-propagation 期间,
recompute_grad
自定义渐变 (tf.custom_gradient
) 在本地化 Tape 对象下重新计算函数。这种在反向传播期间对函数的重新计算会执行冗余计算,但会减少磁带的整体内存使用量。
y = tf.Variable(1.0)
def my_function(x):
tf.print('running')
z = x*y
return z
my_function_recompute = tf.recompute_grad(my_function)
with tf.GradientTape() as tape:
r = tf.constant(1.0)
for i in range(4):
r = my_function_recompute(r)
running
running
running
running
grad = tape.gradient(r, [y])
running
running
running
running
如果没有 recompute_grad
,磁带将包含所有间歇步骤,并且不会执行重新计算。
with tf.GradientTape() as tape:
r = tf.constant(1.0)
for i in range(4):
r = my_function(r)
running
running
running
running
grad = tape.gradient(r, [y])
如果 f
是 tf.keras
Model
或 Layer
对象,则在返回的函数 g
上不提供 f.variables
等方法和属性。或者保留 f
的引用,或者使用 g.__wrapped__
来访问这些变量和方法。
def print_running_and_return(x):
tf.print("running")
return x
model = tf.keras.Sequential([
tf.keras.layers.Lambda(print_running_and_return),
tf.keras.layers.Dense(2)
])
model_recompute = tf.recompute_grad(model)
with tf.GradientTape(persistent=True) as tape:
r = tf.constant([[1,2]])
for i in range(4):
r = model_recompute(r)
running
running
running
running
grad = tape.gradient(r, model.variables)
running
running
running
running
或者,使用__wrapped__
属性访问原始模型对象。
grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running
相关用法
- Python tf.reverse用法及代码示例
- Python tf.register_tensor_conversion_function用法及代码示例
- Python tf.reshape用法及代码示例
- Python tf.reverse_sequence用法及代码示例
- Python tf.repeat用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.raw_ops.SelfAdjointEigV2用法及代码示例
- Python tf.raw_ops.BatchMatMul用法及代码示例
- Python tf.raw_ops.OneHot用法及代码示例
- Python tf.raw_ops.ResourceScatterNdSub用法及代码示例
- Python tf.raw_ops.ReadVariableXlaSplitND用法及代码示例
- Python tf.raw_ops.GatherV2用法及代码示例
- Python tf.raw_ops.Expm1用法及代码示例
- Python tf.range用法及代码示例
- Python tf.raw_ops.BitwiseAnd用法及代码示例
- Python tf.raw_ops.UniqueWithCounts用法及代码示例
- Python tf.raw_ops.DecodeGif用法及代码示例
- Python tf.random.truncated_normal用法及代码示例
- Python tf.raw_ops.Size用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.recompute_grad。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。