使用tf.tpu.experimental.embedding
配置应用标准查找操作。
用法
tf.tpu.experimental.embedding.serving_embedding_lookup(
inputs, weights, tables, feature_config
)
参数
-
inputs
张量、稀疏张量或不规则张量的嵌套结构。 -
weights
Tensors、SparseTensors 或 RaggedTensors 或 None 的嵌套结构,无权重。如果不是 None,则结构必须与输入的结构匹配,但条目允许为 None。 -
tables
将 TableConfig 对象映射到变量的字典。 -
feature_config
FeatureConfig 对象的嵌套结构,具有与输入相同的结构。
返回
- 与输入具有相同结构的张量嵌套结构。
此函数是一个实用程序,它允许将tf.tpu.experimental.embedding
配置对象与标准查找函数一起使用。这可以在导出使用 tf.tpu.experimental.embedding.TPUEmbedding
在 CPU 上服务的模型时使用。特别是 tf.tpu.experimental.embedding.TPUEmbedding
仅支持在 TPU 上查找,不应成为服务图的一部分。
请注意,配置对象中的 TPU 特定选项(例如 max_sequence_length
)将被忽略。
在以下示例中,我们采用经过训练的模型(有关上下文,请参阅tf.tpu.experimental.embedding.TPUEmbedding
的文档)并使用服务函数创建保存的模型,该函数将执行嵌入查找并将结果传递给您的模型:
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=1024,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
@tf.function(input_signature=[{'feature_one':tf.TensorSpec(...),
'feature_two':tf.TensorSpec(...),
'feature_three':tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
embedding_features, None, embedding.embedding_tables,
feature_config)
return model(embedded_features)
model.embedding_api = embedding
tf.saved_model.save(model,
export_dir=...,
signatures={'serving_default':serve_tensors})
注意:将嵌入 api 对象分配给模型的成员很重要,因为 tf.saved_model.save
仅支持将变量保存为一个 Trackable
对象。由于模型的权重在 model
并且嵌入表由 embedding
管理,我们将 embedding
分配给 model
的属性,以便 tf.saved_model.save 可以找到嵌入变量。
注意:相同的serve_tensors
函数和tf.saved_model.save
调用将直接在训练中起作用。
相关用法
- Python tf.tpu.experimental.embedding.FeatureConfig用法及代码示例
- Python tf.tpu.experimental.embedding.FTRL用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.apply_gradients用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.dequeue用法及代码示例
- Python tf.tpu.experimental.embedding.SGD用法及代码示例
- Python tf.tpu.experimental.embedding.TableConfig用法及代码示例
- Python tf.tpu.experimental.embedding.Adam用法及代码示例
- Python tf.tpu.experimental.embedding.Adagrad用法及代码示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.enqueue用法及代码示例
- Python tf.tpu.experimental.DeviceAssignment用法及代码示例
- Python tf.types.experimental.GenericFunction.get_concrete_function用法及代码示例
- Python tf.train.Coordinator.stop_on_exception用法及代码示例
- Python tf.train.ExponentialMovingAverage用法及代码示例
- Python tf.train.Checkpoint.restore用法及代码示例
- Python tf.test.is_built_with_rocm用法及代码示例
- Python tf.train.Checkpoint.read用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tpu.experimental.embedding.serving_embedding_lookup。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。