当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.ReplicaContext用法及代码示例


具有可在副本上下文中调用的 API 集合的类。

用法

tf.distribute.ReplicaContext(
    strategy, replica_id_in_sync_group
)

参数

  • strategy 一个tf.distribute.Strategy
  • replica_id_in_sync_group 整数,Tensor 或无。尽可能首选整数以避免嵌套 tf.function 的问题。它接受 Tensor 只是为了与 tpu.replicate 兼容。

属性

  • devices 以字符串元组的形式返回要执行此副本的设备。 (已弃用)

    警告:此函数已弃用。它将在未来的版本中删除。更新说明:请避免依赖设备属性。

    注意:对于 tf.distribute.MirroredStrategytf.distribute.experimental.MultiWorkerMirroredStrategy ,这将返回设备字符串的嵌套列表,例如 [["GPU:0"]]。

  • num_replicas_in_sync 返回保持同步的副本数。
  • replica_id_in_sync_group 返回副本的 id。

    这标识了所有保持同步的副本中的副本。副本 id 的值范围可以从 0 到 tf.distribute.ReplicaContext.num_replicas_in_sync - 1。

    注意:这不能保证与用于低级别操作(例如 collective_permute)的 XLA 副本 ID 相同。

  • strategy 当前的tf.distribute.Strategy 对象。

您可以使用 tf.distribute.get_replica_context 获取 ReplicaContext 的实例,该实例只能在传递给 tf.distribute.Strategy.run 的函数内部调用。

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
def func():
  replica_context = tf.distribute.get_replica_context()
  return replica_context.replica_id_in_sync_group
strategy.run(func)
PerReplica:{
  0:<tf.Tensor:shape=(), dtype=int32, numpy=0>,
  1:<tf.Tensor:shape=(), dtype=int32, numpy=1>
}

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.ReplicaContext。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。