当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.StrategyExtended.update用法及代码示例


用法

update(
    var, fn, args=(), kwargs=None, group=True
)

参数

  • var 变量,可能镜像到多个设备,以进行操作。
  • fn 要调用的函数。应该将变量作为第一个参数。
  • args 元组或列表。要传递给 fn() 的其他位置参数。
  • kwargs 带有关键字参数的字典传递给 fn()
  • group 布尔值。默认为真。如果为 False,则返回值将被解包。

返回

  • 默认情况下,所有副本的合并返回值 fn。合并的结果具有依赖关系,以确保如果对其进行评估,副作用(更新)将发生在每个副本上。如果改为指定"group=False",则此函数将返回列表嵌套,其中每个列表的每个副本都有一个元素,调用者负责确保执行所有元素。

运行fn 以使用镜像到相同设备的输入更新var

tf.distribute.StrategyExtended.update 采用要更新的分布式变量 var、更新函数 fn 以及 argskwargs 用于 fn 。它将 fn 应用于 var 的每个组件变量,并从 argskwargs 传递相应的值。 argskwargs 都不能包含 per-replica 值。如果它们包含镜像值,它们将在调用 fn 之前被解包。例如,fn 可以是 assign_add 并且 args 可以是镜像 DistributedValues,其中每个组件都包含要添加到此镜像变量 var 的值。调用update 将在var 的每个组件变量上调用assign_add,并在该设备上使用相应的张量值。

示例用法:

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2
devices
with strategy.scope():
  v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):
  return v.assign(1.0)
result = strategy.extended.update(v, update_fn)
# result is
# Mirrored:{
#  0:tf.Tensor(1.0, shape=(), dtype=float32),
#  1:tf.Tensor(1.0, shape=(), dtype=float32)
# }

如果var跨多个设备镜像,则该方法实现逻辑如下:

results = {}
for device, v in var:
  with tf.device(device):
    # args and kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)

否则,此方法返回与 var 并置的 fn(var, *args, **kwargs)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.StrategyExtended.update。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。