当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.DistributedIterator.get_next_as_optional用法及代码示例


用法

get_next_as_optional()

返回

返回包含所有副本的下一个值的 tf.experimental.Optional

如果tf.distribute.DistributedIterator 已到达序列末尾,则返回的tf.experimental.Optional 将没有值。

示例用法:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
global_batch_size = 2
steps_per_loop = 2
dataset = tf.data.Dataset.range(10).batch(global_batch_size)
distributed_iterator = iter(
    strategy.experimental_distribute_dataset(dataset))
def step_fn(x):
  # train the model with inputs
  return x
@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
    tf.print(strategy.experimental_local_results(per_replica_results))
train_fn(distributed_iterator)
# ([0 1], [2 3])
# ([4], [])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.DistributedIterator.get_next_as_optional。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。