返回并创建(如有必要)全局步长张量。
用法
tf.compat.v1.train.get_or_create_global_step(
graph=None
)
参数
-
graph
在其中创建全局步长张量的图。如果缺少,请使用默认图表。
返回
- 全局步长张量。
迁移到 TF2
警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。
随着全局图的弃用,TF 不再跟踪集合中的变量。也就是说,TF2 中没有全局变量。因此,全局阶跃函数已被删除 (get_or_create_global_step
, create_global_step
, get_global_step
)。您有两种迁移选择:
- 创建一个 Keras 优化器,它会生成一个
iterations
变量。调用apply_gradients
时,此变量会自动递增。 - 手动创建并增加
tf.Variable
。
下面是一个从使用全局步骤迁移到使用 Keras 优化器的示例:
定义一个虚拟模型和损失:
def compute_loss(x):
v = tf.Variable(3.0)
y = x * v
loss = x * 5 - x * v
return loss, [v]
迁移前:
g = tf.Graph()
with g.as_default():
x = tf.compat.v1.placeholder(tf.float32, [])
loss, var_list = compute_loss(x)
global_step = tf.compat.v1.train.get_or_create_global_step()
global_init = tf.compat.v1.global_variables_initializer()
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(loss, global_step, var_list)
sess = tf.compat.v1.Session(graph=g)
sess.run(global_init)
print("before training:", sess.run(global_step))
before training:0
sess.run(train_op, feed_dict={x:3})
print("after training:", sess.run(global_step))
after training:1
迁移到 Keras 优化器:
optimizer = tf.keras.optimizers.SGD(.01)
print("before training:", optimizer.iterations.numpy())
before training:0
with tf.GradientTape() as tape:
loss, var_list = compute_loss(3)
grads = tape.gradient(loss, var_list)
optimizer.apply_gradients(zip(grads, var_list))
print("after training:", optimizer.iterations.numpy())
after training:1
相关用法
- Python tf.compat.v1.train.get_global_step用法及代码示例
- Python tf.compat.v1.train.global_step用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代码示例
- Python tf.compat.v1.train.Optimizer用法及代码示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代码示例
- Python tf.compat.v1.train.Checkpoint用法及代码示例
- Python tf.compat.v1.train.Supervisor.managed_session用法及代码示例
- Python tf.compat.v1.train.Checkpoint.restore用法及代码示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.exponential_decay用法及代码示例
- Python tf.compat.v1.train.natural_exp_decay用法及代码示例
- Python tf.compat.v1.train.MomentumOptimizer用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer用法及代码示例
- Python tf.compat.v1.train.GradientDescentOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.linear_cosine_decay用法及代码示例
- Python tf.compat.v1.train.Supervisor用法及代码示例
- Python tf.compat.v1.train.AdagradDAOptimizer.compute_gradients用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.get_or_create_global_step。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。