当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.distribute.MirroredStrategy.make_input_fn_iterator用法及代码示例


用法

make_input_fn_iterator(
    input_fn, replication_mode=tf.distribute.InputReplicationMode.PER_WORKER
)

参数

返回

  • 一个迭代器对象,首先应该是 .initialize() -ed。然后可以将其传递给 strategy.experimental_run() ,或者您可以 iterator.get_next() 获取下一个值以传递给 strategy.extended.call_for_each_replica()

返回从输入函数创建的副本之间拆分的迭代器。

已弃用:此方法在 TF 2.x 中不可用。

input_fn 应该采用 tf.distribute.InputContext 对象,可以访问有关批处理和输入分片的信息:

def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)
with strategy.scope():
  iterator = strategy.make_input_fn_iterator(input_fn)
  replica_results = strategy.experimental_run(replica_fn, iterator)

input_fn 返回的 tf.data.Dataset 应该具有 per-replica 批量大小,可以使用 input_context.get_per_replica_batch_size 计算。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.distribute.MirroredStrategy.make_input_fn_iterator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。