當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.GradientTape.reset用法及代碼示例


用法

reset()

清除此磁帶中存儲的所有信息。

相當於使用新磁帶退出和重新進入磁帶上下文管理器。例如,以下兩個代碼塊是等價的:

with tf.GradientTape() as t:
  loss = loss_fn()
with tf.GradientTape() as t:
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn


# The following is equivalent to the above
with tf.GradientTape() as t:
  loss = loss_fn()
  t.reset()
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn

如果您不想退出磁帶的上下文管理器,或者因為所需的重置點位於控製流結構內而不能退出,這很有用:

with tf.GradientTape() as t:
  loss = ...
  if loss > k:
    t.reset()

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.GradientTape.reset。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。