當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.tpu.batch_parallel用法及代碼示例


沿批處理維度的分片computation 用於並行執行。

用法

tf.compat.v1.tpu.batch_parallel(
    computation:Callable[..., Any],
    inputs:Optional[List[List[Optional[core_types.Tensor]]]] = None,
    num_shards:int = 1,
    infeed_queue:Optional[tpu_feed.InfeedQueue] = None,
    device_assignment:Optional[tf.tpu.experimental.DeviceAssignment] = None,
    name:Optional[Text] = None,
    xla_options:Optional[tf.tpu.XLAOptions] = None
)

參數

  • computation 一個 Python 函數,它構建一個計算以應用於輸入的每個分片。
  • inputs 輸入張量列表或無(相當於一個空列表)。每個張量的第 0 維必須具有可被 num_shards 整除的大小。
  • num_shards 分片的數量。
  • infeed_queue 如果不是 None ,則從 InfeedQueue 將參數元組作為輸入附加到 computation
  • device_assignment 如果不是 None ,則使用 DeviceAssignment 說明計算中的邏輯核心與 TPU 拓撲中的物理核心之間的映射。如果 None ,則使用默認設備分配。 DeviceAssignment 如果計算的每個分片隻使用一個核,或者隻有一個分片,或者分片數等於 TPU 係統中的核數,則可以省略DeviceAssignment
  • name (已棄用)什麽都不做。
  • xla_options tpu.XLAOptions 的實例,指示傳遞給 XLA 編譯器的選項。使用 None 作為默認選項。

返回

  • 輸出張量列表。

拋出

  • ValueError 如果num_shards <= 0

shard() 周圍的便利包裝。

inputs 必須是張量列表或無(相當於一個空列表)。每個輸入沿第 0 維拆分為 num_shards 片段,計算並行應用於每個分片。

如果張量被 computation 詞匯捕獲,則張量將廣播到所有分片。例如:,

x = tf.constant(7) def computation(): return x + 3 ... = shard(computation, ...)

所有分片的輸出沿它們的第 0 維連接在一起。

計算的輸入和輸出必須至少是 rank-1 張量。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.tpu.batch_parallel。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。