當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.distribute.DistributedIterator.get_next用法及代碼示例


用法

get_next()

返回

拋出

從迭代器返回所有副本的下一個輸入。

示例使用:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)
@tf.function
def one_step(input):
  return input
step_num = 5
for _ in range(step_num):
  strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
strategy.experimental_local_results(dist_dataset_iterator.get_next())
(<tf.Tensor:shape=(1,), dtype=int64, numpy=array([10])>,
 <tf.Tensor:shape=(1,), dtype=int64, numpy=array([11])>)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.distribute.DistributedIterator.get_next。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。