本文整理匯總了Python中tensorflow.contrib.tpu.python.tpu.tpu_optimizer.CrossShardOptimizer方法的典型用法代碼示例。如果您正苦於以下問題:Python tpu_optimizer.CrossShardOptimizer方法的具體用法?Python tpu_optimizer.CrossShardOptimizer怎麽用?Python tpu_optimizer.CrossShardOptimizer使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類tensorflow.contrib.tpu.python.tpu.tpu_optimizer
的用法示例。
在下文中一共展示了tpu_optimizer.CrossShardOptimizer方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: get_train_op
# 需要導入模塊: from tensorflow.contrib.tpu.python.tpu import tpu_optimizer [as 別名]
# 或者: from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import CrossShardOptimizer [as 別名]
def get_train_op(self, loss):
"""Creates a training op.
Args:
loss: A float32 `Tensor` representing the total training loss.
Returns:
train_op: A slim.learning.create_train_op train_op.
Raises:
ValueError: If specified optimizer isn't supported.
"""
# Get variables to train (defined in subclass).
assert self.variables_to_train
# Define a learning rate schedule.
decay_steps = self._config.learning.decay_steps
decay_factor = self._config.learning.decay_factor
learning_rate = float(self._config.learning.learning_rate)
# Define a learning rate schedule.
global_step = slim.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
learning_rate,
global_step,
decay_steps,
decay_factor,
staircase=True)
# Create an optimizer.
opt_type = self._config.learning.optimizer
if opt_type == 'adam':
opt = tf.train.AdamOptimizer(learning_rate)
elif opt_type == 'momentum':
opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
elif opt_type == 'rmsprop':
opt = tf.train.RMSPropOptimizer(learning_rate, momentum=0.9,
epsilon=1.0, decay=0.9)
else:
raise ValueError('Unsupported optimizer %s' % opt_type)
if self._config.use_tpu:
opt = tpu_optimizer.CrossShardOptimizer(opt)
# Create a training op.
# train_op = opt.minimize(loss, var_list=self.variables_to_train)
# Create a training op.
train_op = slim.learning.create_train_op(
loss,
optimizer=opt,
variables_to_train=self.variables_to_train,
update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
return train_op