当前位置: 首页>>代码示例>>Python>>正文


Python tpu.CrossShardOptimizer方法代码示例

本文整理汇总了Python中tensorflow.contrib.tpu.CrossShardOptimizer方法的典型用法代码示例。如果您正苦于以下问题:Python tpu.CrossShardOptimizer方法的具体用法?Python tpu.CrossShardOptimizer怎么用?Python tpu.CrossShardOptimizer使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.tpu的用法示例。


在下文中一共展示了tpu.CrossShardOptimizer方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_cross_shard_optimizer

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import CrossShardOptimizer [as 别名]
def get_cross_shard_optimizer(optimizer, disable_for_cpu_debugging=False):
  if disable_for_cpu_debugging:
    return optimizer
  return contrib_tpu.CrossShardOptimizer(optimizer) 
开发者ID:google-research,项目名称:tensor2robot,代码行数:6,代码来源:tpu_model_wrapper.py

示例2: get_train_op

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import CrossShardOptimizer [as 别名]
def get_train_op(self, loss,  # pylint: disable=missing-docstring
                   var_list=None,
                   add_reg_loss=True,
                   use_tpu=False):

    if add_reg_loss:
      l2_loss = tf.reduce_sum(tf.losses.get_regularization_losses())
      loss += l2_loss

    optimizer = FLAGS.optimizer
    if optimizer == 'sgd':
      optimizer = tf.train.MomentumOptimizer(learning_rate=self.lr,
                                             momentum=0.9)
    elif optimizer == 'adam':
      optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
    else:
      raise ValueError('Unknown optimizer: %s' % optimizer)

    if use_tpu:
      # Wrap optimizer in CrossShardOptimizer which takes care of
      # synchronizing the weight updates between TPU cores.
      optimizer = CrossShardOptimizer(optimizer)

    opt_step = optimizer.minimize(loss, var_list=var_list,
                                  colocate_gradients_with_ops=True)

    if self.update_batchnorm_params:
      opt_step = tf.group([opt_step] +
                          tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    opt_step = tf.group([opt_step, self.global_step_inc])

    return opt_step 
开发者ID:google-research,项目名称:s4l,代码行数:35,代码来源:trainer.py


注:本文中的tensorflow.contrib.tpu.CrossShardOptimizer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。