用法
compute_gradients(
loss, var_list=None, gate_gradients=GATE_OP, aggregation_method=None,
colocate_gradients_with_ops=False, grad_loss=None
)
为 var_list
中的变量计算 loss
的梯度。
迁移到 TF2
警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。
TF2 中的tf.keras.optimizers.Optimizer
没有提供compute_gradients
方法,您应该使用tf.GradientTape
来获取渐变:
@tf.function
def train step(inputs):
batch_data, labels = inputs
with tf.GradientTape() as tape:
predictions = model(batch_data, training=True)
loss = tf.keras.losses.CategoricalCrossentropy(
reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Args:损失:包含要最小化的值的张量或不带参数的可调用函数,它返回要最小化的值。启用即刻执行后,它必须是可调用的。 var_list:要更新以最小化 loss
的 tf.Variable
的可选列表或元组。默认为键 GraphKeys.TRAINABLE_VARIABLES
下图表中收集的变量列表。 gate_gradients:如何对梯度的计算进行门控。可以是 GATE_NONE
, GATE_OP
或 GATE_GRAPH
。 aggregation_method:指定用于组合梯度项的方法。有效值在类 AggregationMethod
中定义。 colocate_gradients_with_ops:如果为真,请尝试将渐变与相应的操作放在一起。 grad_loss:可选。一个 Tensor
保存为 loss
计算的梯度。
返回:(梯度,变量)对的列表。变量始终存在,但梯度可以是 None
。
引发:类型错误:如果 var_list
包含除 Variable
对象之外的任何其他内容。 ValueError:如果某些参数无效。 RuntimeError:如果在启用即刻执行的情况下调用并且loss
不可调用。
@compatibility(eager) 启用即刻执行时,gate_gradients
, aggregation_method
和 colocate_gradients_with_ops
将被忽略。
说明
这是 minimize()
的第一部分。它返回一个(梯度,变量)对列表,其中"gradient" 是"variable" 的梯度。请注意,如果给定变量没有梯度,"gradient" 可以是 Tensor
、IndexedSlices
或 None
。
相关用法
- Python tf.compat.v1.train.MomentumOptimizer用法及代码示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代码示例
- Python tf.compat.v1.train.MonitoredSession用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代码示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代码示例
- Python tf.compat.v1.train.Optimizer用法及代码示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代码示例
- Python tf.compat.v1.train.Checkpoint用法及代码示例
- Python tf.compat.v1.train.Supervisor.managed_session用法及代码示例
- Python tf.compat.v1.train.Checkpoint.restore用法及代码示例
- Python tf.compat.v1.train.global_step用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.exponential_decay用法及代码示例
- Python tf.compat.v1.train.natural_exp_decay用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer用法及代码示例
- Python tf.compat.v1.train.get_global_step用法及代码示例
- Python tf.compat.v1.train.GradientDescentOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.linear_cosine_decay用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.MomentumOptimizer.compute_gradients。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。