當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.distribute.MirroredStrategy.make_input_fn_iterator用法及代碼示例


用法

make_input_fn_iterator(
    input_fn, replication_mode=tf.distribute.InputReplicationMode.PER_WORKER
)

參數

返回

  • 一個迭代器對象,首先應該是 .initialize() -ed。然後可以將其傳遞給 strategy.experimental_run() ,或者您可以 iterator.get_next() 獲取下一個值以傳遞給 strategy.extended.call_for_each_replica()

返回從輸入函數創建的副本之間拆分的迭代器。

已棄用:此方法在 TF 2.x 中不可用。

input_fn 應該采用 tf.distribute.InputContext 對象,可以訪問有關批處理和輸入分片的信息:

def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)
with strategy.scope():
  iterator = strategy.make_input_fn_iterator(input_fn)
  replica_results = strategy.experimental_run(replica_fn, iterator)

input_fn 返回的 tf.data.Dataset 應該具有 per-replica 批量大小,可以使用 input_context.get_per_replica_batch_size 計算。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distribute.MirroredStrategy.make_input_fn_iterator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。