當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.distribute.StrategyExtended.update用法及代碼示例


用法

update(
    var, fn, args=(), kwargs=None, group=True
)

參數

  • var 變量,可能鏡像到多個設備,以進行操作。
  • fn 要調用的函數。應該將變量作為第一個參數。
  • args 元組或列表。要傳遞給 fn() 的其他位置參數。
  • kwargs 帶有關鍵字參數的字典傳遞給 fn()
  • group 布爾值。默認為真。如果為 False,則返回值將被解包。

返回

  • 默認情況下,所有副本的合並返回值 fn。合並的結果具有依賴關係,以確保如果對其進行評估,副作用(更新)將發生在每個副本上。如果改為指定"group=False",則此函數將返回列表嵌套,其中每個列表的每個副本都有一個元素,調用者負責確保執行所有元素。

運行fn 以使用鏡像到相同設備的輸入更新var

tf.distribute.StrategyExtended.update 采用要更新的分布式變量 var、更新函數 fn 以及 argskwargs 用於 fn 。它將 fn 應用於 var 的每個組件變量,並從 argskwargs 傳遞相應的值。 argskwargs 都不能包含 per-replica 值。如果它們包含鏡像值,它們將在調用 fn 之前被解包。例如,fn 可以是 assign_add 並且 args 可以是鏡像 DistributedValues,其中每個組件都包含要添加到此鏡像變量 var 的值。調用update 將在var 的每個組件變量上調用assign_add,並在該設備上使用相應的張量值。

示例用法:

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2
devices
with strategy.scope():
  v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):
  return v.assign(1.0)
result = strategy.extended.update(v, update_fn)
# result is
# Mirrored:{
#  0:tf.Tensor(1.0, shape=(), dtype=float32),
#  1:tf.Tensor(1.0, shape=(), dtype=float32)
# }

如果var跨多個設備鏡像,則該方法實現邏輯如下:

results = {}
for device, v in var:
  with tf.device(device):
    # args and kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)

否則,此方法返回與 var 並置的 fn(var, *args, **kwargs)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distribute.StrategyExtended.update。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。