當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.distribute.OneDeviceStrategy.distribute_datasets_from_function用法及代碼示例


用法

distribute_datasets_from_function(
    dataset_fn, options=None
)

參數

返回

  • 一個“分布式Dataset”,調用者可以像常規數據集一樣對其進行迭代。

分發由對 dataset_fn 的調用創建的 tf.data.Dataset 實例。

dataset_fn 將為策略中的每個工作人員調用一次。在這種情況下,我們隻有一個工人和一個設備,所以 dataset_fn 被調用一次。

dataset_fn 應該采用 tf.distribute.InputContext 實例,其中可以訪問有關批處理和輸入複製的信息:

def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)

inputs = strategy.distribute_datasets_from_function(dataset_fn)

for batch in inputs:
  replica_results = strategy.run(replica_fn, args=(batch,))

重要的:dataset_fn 返回的 tf.data.Dataset 應該具有 per-replica 批量大小,這與使用全局批量大小的 experimental_distribute_dataset 不同。這可以使用 input_context.get_per_replica_batch_size 來計算。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.distribute.OneDeviceStrategy.distribute_datasets_from_function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。