將函數定義為磁帶 auto-diff 的重新計算檢查點。
用法
tf.recompute_grad(
f
)
參數
-
f
函數f(*x)
返回Tensor
或Tensor
輸出序列。
返回
-
包裝
f
的函數g
定義了自定義漸變,它在漸變調用的反向傳遞中重新計算f
。
磁帶檢查點是一種減少 auto-diff 磁帶內存消耗的技術:
在沒有磁帶檢查點操作的情況下,中間值被記錄到磁帶中以用於反向傳遞。
使用磁帶檢查點,隻記錄函數調用及其輸入。在back-propagation 期間,
recompute_grad
自定義漸變 (tf.custom_gradient
) 在本地化 Tape 對象下重新計算函數。這種在反向傳播期間對函數的重新計算會執行冗餘計算,但會減少磁帶的整體內存使用量。
y = tf.Variable(1.0)
def my_function(x):
tf.print('running')
z = x*y
return z
my_function_recompute = tf.recompute_grad(my_function)
with tf.GradientTape() as tape:
r = tf.constant(1.0)
for i in range(4):
r = my_function_recompute(r)
running
running
running
running
grad = tape.gradient(r, [y])
running
running
running
running
如果沒有 recompute_grad
,磁帶將包含所有間歇步驟,並且不會執行重新計算。
with tf.GradientTape() as tape:
r = tf.constant(1.0)
for i in range(4):
r = my_function(r)
running
running
running
running
grad = tape.gradient(r, [y])
如果 f
是 tf.keras
Model
或 Layer
對象,則在返回的函數 g
上不提供 f.variables
等方法和屬性。或者保留 f
的引用,或者使用 g.__wrapped__
來訪問這些變量和方法。
def print_running_and_return(x):
tf.print("running")
return x
model = tf.keras.Sequential([
tf.keras.layers.Lambda(print_running_and_return),
tf.keras.layers.Dense(2)
])
model_recompute = tf.recompute_grad(model)
with tf.GradientTape(persistent=True) as tape:
r = tf.constant([[1,2]])
for i in range(4):
r = model_recompute(r)
running
running
running
running
grad = tape.gradient(r, model.variables)
running
running
running
running
或者,使用__wrapped__
屬性訪問原始模型對象。
grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running
相關用法
- Python tf.reverse用法及代碼示例
- Python tf.register_tensor_conversion_function用法及代碼示例
- Python tf.reshape用法及代碼示例
- Python tf.reverse_sequence用法及代碼示例
- Python tf.repeat用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.raw_ops.SelfAdjointEigV2用法及代碼示例
- Python tf.raw_ops.BatchMatMul用法及代碼示例
- Python tf.raw_ops.OneHot用法及代碼示例
- Python tf.raw_ops.ResourceScatterNdSub用法及代碼示例
- Python tf.raw_ops.ReadVariableXlaSplitND用法及代碼示例
- Python tf.raw_ops.GatherV2用法及代碼示例
- Python tf.raw_ops.Expm1用法及代碼示例
- Python tf.range用法及代碼示例
- Python tf.raw_ops.BitwiseAnd用法及代碼示例
- Python tf.raw_ops.UniqueWithCounts用法及代碼示例
- Python tf.raw_ops.DecodeGif用法及代碼示例
- Python tf.random.truncated_normal用法及代碼示例
- Python tf.raw_ops.Size用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.recompute_grad。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。