This article collects typical usage examples of the Python method tensorflow.contrib.tpu.python.tpu.tpu_optimizer.CrossShardOptimizer. If you are wondering what tpu_optimizer.CrossShardOptimizer does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore the module where the method lives, tensorflow.contrib.tpu.python.tpu.tpu_optimizer, for more context.
Below is 1 code example of the tpu_optimizer.CrossShardOptimizer method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
Example 1: get_train_op
# Required import: from tensorflow.contrib.tpu.python.tpu import tpu_optimizer [as alias]
# Or: from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import CrossShardOptimizer [as alias]
def get_train_op(self, loss):
  """Creates a training op.

  Args:
    loss: A float32 `Tensor` representing the total training loss.

  Returns:
    train_op: The training op created by slim.learning.create_train_op.

  Raises:
    ValueError: If the specified optimizer isn't supported.
  """
  # Get variables to train (defined in subclass).
  assert self.variables_to_train

  # Read the learning-rate parameters from the config.
  decay_steps = self._config.learning.decay_steps
  decay_factor = self._config.learning.decay_factor
  learning_rate = float(self._config.learning.learning_rate)

  # Define a staircase exponential-decay schedule:
  # lr * decay_factor ** (global_step // decay_steps).
  global_step = slim.get_or_create_global_step()
  learning_rate = tf.train.exponential_decay(
      learning_rate,
      global_step,
      decay_steps,
      decay_factor,
      staircase=True)

  # Create an optimizer.
  opt_type = self._config.learning.optimizer
  if opt_type == 'adam':
    opt = tf.train.AdamOptimizer(learning_rate)
  elif opt_type == 'momentum':
    opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
  elif opt_type == 'rmsprop':
    opt = tf.train.RMSPropOptimizer(learning_rate, momentum=0.9,
                                    epsilon=1.0, decay=0.9)
  else:
    raise ValueError('Unsupported optimizer %s' % opt_type)

  # On TPU, wrap the base optimizer so gradients are aggregated across
  # shards before being applied.
  if self._config.use_tpu:
    opt = tpu_optimizer.CrossShardOptimizer(opt)

  # Create a training op that also runs the ops in UPDATE_OPS
  # (e.g. batch-norm moving-average updates).
  # Alternative: train_op = opt.minimize(loss, var_list=self.variables_to_train)
  train_op = slim.learning.create_train_op(
      loss,
      optimizer=opt,
      variables_to_train=self.variables_to_train,
      update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
  return train_op
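For reference, CrossShardOptimizer is a thin wrapper around a regular tf.train optimizer: it sums each gradient across TPU replicas with a cross-replica sum before delegating to the wrapped optimizer, which (with the default MEAN reduction) effectively averages gradients across shards. Below is a minimal, self-contained sketch of the same wrapping pattern inside a TPUEstimator model_fn. The tiny model, the loss, and the params['use_tpu'] flag are illustrative assumptions mirroring self._config.use_tpu above, not part of the original example, and the code assumes TensorFlow 1.x with tf.contrib available.

# A minimal sketch (assumptions: TF 1.x with tf.contrib; the model, loss,
# and flags below are illustrative and not from the original example).
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

def model_fn(features, labels, mode, params):
  """Toy TPUEstimator model_fn showing the CrossShardOptimizer pattern."""
  logits = tf.layers.dense(features['x'], units=10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  if params['use_tpu']:  # hypothetical flag, analogous to self._config.use_tpu
    # Wrap the base optimizer; gradients are then summed across TPU
    # shards with a cross-replica sum before being applied.
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

  train_op = optimizer.minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)

Note that CrossShardOptimizer only changes how gradients are aggregated across shards; the choice of base optimizer (Adam, Momentum, or RMSProp in the example above) is unaffected.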