當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.recompute_grad用法及代碼示例


將函數定義為磁帶 auto-diff 的重新計算檢查點。

用法

tf.recompute_grad(
    f
)

參數

  • f 函數 f(*x) 返回 TensorTensor 輸出序列。

返回

  • 包裝 f 的函數 g 定義了自定義漸變,它在漸變調用的反向傳遞中重新計算 f

磁帶檢查點是一種減少 auto-diff 磁帶內存消耗的技術:

  • 在沒有磁帶檢查點操作的情況下,中間值被記錄到磁帶中以用於反向傳遞。

  • 使用磁帶檢查點,隻記錄函數調用及其輸入。在back-propagation 期間,recompute_grad 自定義漸變 (tf.custom_gradient) 在本地化 Tape 對象下重新計算函數。這種在反向傳播期間對函數的重新計算會執行冗餘計算,但會減少磁帶的整體內存使用量。

y = tf.Variable(1.0)
def my_function(x):
  tf.print('running')
  z = x*y
  return z
my_function_recompute = tf.recompute_grad(my_function)
with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function_recompute(r)
running
running
running
running
grad = tape.gradient(r, [y])
running
running
running
running

如果沒有 recompute_grad ,磁帶將包含所有間歇步驟,並且不會執行重新計算。

with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function(r)
running
running
running
running
grad = tape.gradient(r, [y])

如果 ftf.keras ModelLayer 對象,則在返回的函數 g 上不提供 f.variables 等方法和屬性。或者保留 f 的引用,或者使用 g.__wrapped__ 來訪問這些變量和方法。

def print_running_and_return(x):
  tf.print("running")
  return x
model = tf.keras.Sequential([
  tf.keras.layers.Lambda(print_running_and_return),
  tf.keras.layers.Dense(2)
])
model_recompute = tf.recompute_grad(model)
with tf.GradientTape(persistent=True) as tape:
  r = tf.constant([[1,2]])
  for i in range(4):
    r = model_recompute(r)
running
running
running
running
grad = tape.gradient(r, model.variables)
running
running
running
running

或者,使用__wrapped__ 屬性訪問原始模型對象。

grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.recompute_grad。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。