用法
enqueue(
features,
weights=None,
training:bool = True,
name:Optional[Text] = None,
device:Optional[Text] = None
)
参数
-
features
tf.Tensor
s、tf.SparseTensor
s 或tf.RaggedTensor
s 的嵌套结构,与feature_config
具有相同的结构。输入将向下转换为tf.int32
。每次调用仅支持tf.SparseTensor
或tf.RaggedTensor
中的一种类型。 -
weights
如果不是None
,则为tf.Tensor
s、tf.SparseTensor
s 或tf.RaggedTensor
s 的嵌套结构,与上述匹配,但张量应为浮点类型(它们将向下转换为tf.float32
)。对于tf.SparseTensor
,我们假设indices
对于来自features
的并行条目是相同的,同样对于tf.RaggedTensor
,我们假设row_splits是相同的。 -
training
默认为True
。如果False
,将批次作为推理批次排入队列(仅限前向传递)。当这是False
时不要调用apply_gradients
,因为这可能会导致死锁。 name:底层操作的名称。设备:该批次应入队的设备名称(例如'/task:0/device:TPU:2')。当且仅当 features 不是tf.distribute.DistributedValues
并且未在 TPU 上下文中调用 enqueue(例如,在TPUStrategy.run
中)时,才应设置此项。
抛出
-
ValueError
在 strategy.run 调用中调用时,输入不是直接取自strategy.run
调用的参数。此外,如果features
中任何序列的大小与feature_config
中的相应序列不匹配。同样对于weights
,如果不是None
。如果特征的输入形状与之前的调用不相等或不同。 -
RuntimeError
在 strategy.run 调用和 XLA 控制流内部调用时。如果无法确定batch_size 并且未调用构建。 -
TypeError
如果features
中任何序列的类型与feature_config
中的对应序列不匹配。同样对于weights
,如果不是None
。
将 id 张量排入队列以进行嵌入查找。
此函数将要在嵌入表中查找的特征结构排入队列。我们期望特征中每个张量的输入形状与通过 FeatureConfig 或构建方法(如果有)设置的输出形状相匹配。输出形状将根据带有max_sequence_length 的输入形状或 FeatureConfig 中的输出形状设置自动检测。请注意,输出形状基于每个副本的批量大小。如果您的输入数据集被批处理为全局批处理大小并且您使用 tf.distribute.TPUStrategy
的 experimental_distribute_dataset
或者如果您使用 distribute_datasets_from_function
并批处理到由传递给输入函数的上下文计算的每个核心批处理大小,则输出形状应该自动匹配。
自动检测到输出形状:
- 对于稠密张量,如果 rank 2 或更高,请确保张量的最后一维为 1。输出形状将是不包括最后一维的输入形状。
- 对于稀疏张量,请确保张量的秩为 2 及以上。一种。如果特征配置有 max_sequence_length 等于 0 或输出形状集(max_sequence_length 设置将被忽略),输出形状将是不包括最后一个维度的输入形状。湾。否则,如果张量为 2 级,则输出形状将是输入形状,最后一维设置为 max_sequence_length。如果张量高于 2 阶,则输出形状将是输入形状,不包括最后一个维度,并且输出形状的最后一个维度将设置为 max_sequence_length。
- 对于参差不齐的张量,确保张量的秩为 2。如果特征配置有 max_sequence_length 等于 0 或输出形状集(max_sequence_length 设置将被忽略),输出形状将是不包括最后一个维度的输入形状。湾。否则,输出形状将是不包括最后一个维度的输入形状,并且输出形状的最后一个维度将设置为max_sequence_length。
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
@tf.function
def training_step():
def tpu_step(tpu_features):
with tf.GradientTape() as tape:
activations = embedding.dequeue()
tape.watch(activations)
loss = ... # some computation involving activations
embedding_gradients = tape.gradient(loss, activations)
embedding.apply_gradients(embedding_gradients)
embedding_features, tpu_features = next(dataset_iterator)
embedding.enqueue(embedding_features, training=True)
strategy.run(tpu_step, args=(tpu_features,))
training_step()
注意:如上所述使用embedding.apply_gradients
时应指定training=True
,不使用embedding.apply_gradients
时应指定training=False
(例如,用于冻结嵌入或进行评估时)。
对于更细粒度的控制,在上面的示例中,该行
embedding.enqueue(embedding_features, training=True)
可以替换为
per_core_embedding_features = self.strategy.experimental_local_results(
embedding_features)
def per_core_enqueue(ctx):
core_id = ctx.replica_id_in_sync_group
device = strategy.extended.worker_devices[core_id]
embedding.enqueue(per_core_embedding_features[core_id],
device=device)
strategy.experimental_distribute_values_from_function(
per_core_queue_inputs)
相关用法
- Python tf.tpu.experimental.embedding.TPUEmbedding.apply_gradients用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.dequeue用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding用法及代码示例
- Python tf.tpu.experimental.embedding.TableConfig用法及代码示例
- Python tf.tpu.experimental.embedding.FeatureConfig用法及代码示例
- Python tf.tpu.experimental.embedding.FTRL用法及代码示例
- Python tf.tpu.experimental.embedding.SGD用法及代码示例
- Python tf.tpu.experimental.embedding.Adam用法及代码示例
- Python tf.tpu.experimental.embedding.Adagrad用法及代码示例
- Python tf.tpu.experimental.embedding.serving_embedding_lookup用法及代码示例
- Python tf.tpu.experimental.DeviceAssignment用法及代码示例
- Python tf.types.experimental.GenericFunction.get_concrete_function用法及代码示例
- Python tf.train.Coordinator.stop_on_exception用法及代码示例
- Python tf.train.ExponentialMovingAverage用法及代码示例
- Python tf.train.Checkpoint.restore用法及代码示例
- Python tf.test.is_built_with_rocm用法及代码示例
- Python tf.train.Checkpoint.read用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tpu.experimental.embedding.TPUEmbedding.enqueue。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。