用法
enqueue(
features,
weights=None,
training:bool = True,
name:Optional[Text] = None,
device:Optional[Text] = None
)
參數
-
features
tf.Tensor
s、tf.SparseTensor
s 或tf.RaggedTensor
s 的嵌套結構,與feature_config
具有相同的結構。輸入將向下轉換為tf.int32
。每次調用僅支持tf.SparseTensor
或tf.RaggedTensor
中的一種類型。 -
weights
如果不是None
,則為tf.Tensor
s、tf.SparseTensor
s 或tf.RaggedTensor
s 的嵌套結構,與上述匹配,但張量應為浮點類型(它們將向下轉換為tf.float32
)。對於tf.SparseTensor
,我們假設indices
對於來自features
的並行條目是相同的,同樣對於tf.RaggedTensor
,我們假設row_splits是相同的。 -
training
默認為True
。如果False
,將批次作為推理批次排入隊列(僅限前向傳遞)。當這是False
時不要調用apply_gradients
,因為這可能會導致死鎖。 name:底層操作的名稱。設備:該批次應入隊的設備名稱(例如'/task:0/device:TPU:2')。當且僅當 features 不是tf.distribute.DistributedValues
並且未在 TPU 上下文中調用 enqueue(例如,在TPUStrategy.run
中)時,才應設置此項。
拋出
-
ValueError
在 strategy.run 調用中調用時,輸入不是直接取自strategy.run
調用的參數。此外,如果features
中任何序列的大小與feature_config
中的相應序列不匹配。同樣對於weights
,如果不是None
。如果特征的輸入形狀與之前的調用不相等或不同。 -
RuntimeError
在 strategy.run 調用和 XLA 控製流內部調用時。如果無法確定batch_size 並且未調用構建。 -
TypeError
如果features
中任何序列的類型與feature_config
中的對應序列不匹配。同樣對於weights
,如果不是None
。
將 id 張量排入隊列以進行嵌入查找。
此函數將要在嵌入表中查找的特征結構排入隊列。我們期望特征中每個張量的輸入形狀與通過 FeatureConfig 或構建方法(如果有)設置的輸出形狀相匹配。輸出形狀將根據帶有max_sequence_length 的輸入形狀或 FeatureConfig 中的輸出形狀設置自動檢測。請注意,輸出形狀基於每個副本的批量大小。如果您的輸入數據集被批處理為全局批處理大小並且您使用 tf.distribute.TPUStrategy
的 experimental_distribute_dataset
或者如果您使用 distribute_datasets_from_function
並批處理到由傳遞給輸入函數的上下文計算的每個核心批處理大小,則輸出形狀應該自動匹配。
自動檢測到輸出形狀:
- 對於稠密張量,如果 rank 2 或更高,請確保張量的最後一維為 1。輸出形狀將是不包括最後一維的輸入形狀。
- 對於稀疏張量,請確保張量的秩為 2 及以上。一種。如果特征配置有 max_sequence_length 等於 0 或輸出形狀集(max_sequence_length 設置將被忽略),輸出形狀將是不包括最後一個維度的輸入形狀。灣。否則,如果張量為 2 級,則輸出形狀將是輸入形狀,最後一維設置為 max_sequence_length。如果張量高於 2 階,則輸出形狀將是輸入形狀,不包括最後一個維度,並且輸出形狀的最後一個維度將設置為 max_sequence_length。
- 對於參差不齊的張量,確保張量的秩為 2。如果特征配置有 max_sequence_length 等於 0 或輸出形狀集(max_sequence_length 設置將被忽略),輸出形狀將是不包括最後一個維度的輸入形狀。灣。否則,輸出形狀將是不包括最後一個維度的輸入形狀,並且輸出形狀的最後一個維度將設置為max_sequence_length。
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
@tf.function
def training_step():
def tpu_step(tpu_features):
with tf.GradientTape() as tape:
activations = embedding.dequeue()
tape.watch(activations)
loss = ... # some computation involving activations
embedding_gradients = tape.gradient(loss, activations)
embedding.apply_gradients(embedding_gradients)
embedding_features, tpu_features = next(dataset_iterator)
embedding.enqueue(embedding_features, training=True)
strategy.run(tpu_step, args=(tpu_features,))
training_step()
注意:如上所述使用embedding.apply_gradients
時應指定training=True
,不使用embedding.apply_gradients
時應指定training=False
(例如,用於凍結嵌入或進行評估時)。
對於更細粒度的控製,在上麵的示例中,該行
embedding.enqueue(embedding_features, training=True)
可以替換為
per_core_embedding_features = self.strategy.experimental_local_results(
embedding_features)
def per_core_enqueue(ctx):
core_id = ctx.replica_id_in_sync_group
device = strategy.extended.worker_devices[core_id]
embedding.enqueue(per_core_embedding_features[core_id],
device=device)
strategy.experimental_distribute_values_from_function(
per_core_queue_inputs)
相關用法
- Python tf.tpu.experimental.embedding.TPUEmbedding.apply_gradients用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.dequeue用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding用法及代碼示例
- Python tf.tpu.experimental.embedding.TableConfig用法及代碼示例
- Python tf.tpu.experimental.embedding.FeatureConfig用法及代碼示例
- Python tf.tpu.experimental.embedding.FTRL用法及代碼示例
- Python tf.tpu.experimental.embedding.SGD用法及代碼示例
- Python tf.tpu.experimental.embedding.Adam用法及代碼示例
- Python tf.tpu.experimental.embedding.Adagrad用法及代碼示例
- Python tf.tpu.experimental.embedding.serving_embedding_lookup用法及代碼示例
- Python tf.tpu.experimental.DeviceAssignment用法及代碼示例
- Python tf.types.experimental.GenericFunction.get_concrete_function用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.Checkpoint.restore用法及代碼示例
- Python tf.test.is_built_with_rocm用法及代碼示例
- Python tf.train.Checkpoint.read用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.tpu.experimental.embedding.TPUEmbedding.enqueue。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。