用法
dequeue(
name:Optional[Text] = None
)
参数
-
name
底层操作的名称。
返回
-
张量的嵌套结构,结构与
feature_config
相同
抛出
-
RuntimeError
如果在TPUStrategy
下未创建对象或未构建对象时调用(通过手动调用 build 或调用 enqueue)。
获取嵌入结果。
返回 tf.Tensor
对象的嵌套结构,将 feature_config
参数的结构与 TPUEmbedding
类匹配。张量的输出形状是 (*output_shape, dim)
, dim
是对应的 TableConfig
的维度。对于output_shape,可以设置三个位置。
- FeatureConfig 在 init 函数中提供。
- Per_replica_output_shapes通过初始化tpu嵌入类后直接调用build方法。
- 从输入特征的形状自动检测。这些地方的优先级是完全相同的顺序。
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
@tf.function
def training_step():
def tpu_step(tpu_features):
with tf.GradientTape() as tape:
activations = embedding.dequeue()
tape.watch(activations)
loss = ... # some computation involving activations
embedding_gradients = tape.gradient(loss, activations)
embedding.apply_gradients(embedding_gradients)
embedding_features, tpu_features = next(dataset_iterator)
embedding.enqueue(embedding_features, training=True)
strategy.run(tpu_step, args=(tpu_features, ))
training_step()
传递给 TPUEmbedding
对象的实例。
相关用法
- Python tf.tpu.experimental.embedding.TPUEmbedding.apply_gradients用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.enqueue用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding用法及代码示例
- Python tf.tpu.experimental.embedding.TableConfig用法及代码示例
- Python tf.tpu.experimental.embedding.FeatureConfig用法及代码示例
- Python tf.tpu.experimental.embedding.FTRL用法及代码示例
- Python tf.tpu.experimental.embedding.SGD用法及代码示例
- Python tf.tpu.experimental.embedding.Adam用法及代码示例
- Python tf.tpu.experimental.embedding.Adagrad用法及代码示例
- Python tf.tpu.experimental.embedding.serving_embedding_lookup用法及代码示例
- Python tf.tpu.experimental.DeviceAssignment用法及代码示例
- Python tf.types.experimental.GenericFunction.get_concrete_function用法及代码示例
- Python tf.train.Coordinator.stop_on_exception用法及代码示例
- Python tf.train.ExponentialMovingAverage用法及代码示例
- Python tf.train.Checkpoint.restore用法及代码示例
- Python tf.test.is_built_with_rocm用法及代码示例
- Python tf.train.Checkpoint.read用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tpu.experimental.embedding.TPUEmbedding.dequeue。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。