当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.GradientTape.reset用法及代码示例


用法

reset()

清除此磁带中存储的所有信息。

相当于使用新磁带退出和重新进入磁带上下文管理器。例如,以下两个代码块是等价的:

with tf.GradientTape() as t:
  loss = loss_fn()
with tf.GradientTape() as t:
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn


# The following is equivalent to the above
with tf.GradientTape() as t:
  loss = loss_fn()
  t.reset()
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn

如果您不想退出磁带的上下文管理器,或者因为所需的重置点位于控制流结构内而不能退出,这很有用:

with tf.GradientTape() as t:
  loss = ...
  if loss > k:
    t.reset()

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.GradientTape.reset。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。