当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.tpu.experimental.embedding.serving_embedding_lookup用法及代码示例


使用tf.tpu.experimental.embedding 配置应用标准查找操作。

用法

tf.tpu.experimental.embedding.serving_embedding_lookup(
    inputs, weights, tables, feature_config
)

参数

  • inputs 张量、稀疏张量或不规则张量的嵌套结构。
  • weights Tensors、SparseTensors 或 RaggedTensors 或 None 的嵌套结构,无权重。如果不是 None,则结构必须与输入的结构匹配,但条目允许为 None。
  • tables 将 TableConfig 对象映射到变量的字典。
  • feature_config FeatureConfig 对象的嵌套结构,具有与输入相同的结构。

返回

  • 与输入具有相同结构的张量嵌套结构。

此函数是一个实用程序,它允许将tf.tpu.experimental.embedding 配置对象与标准查找函数一起使用。这可以在导出使用 tf.tpu.experimental.embedding.TPUEmbedding 在 CPU 上服务的模型时使用。特别是 tf.tpu.experimental.embedding.TPUEmbedding 仅支持在 TPU 上查找,不应成为服务图的一部分。

请注意,配置对象中的 TPU 特定选项(例如 max_sequence_length )将被忽略。

在以下示例中,我们采用经过训练的模型(有关上下文,请参阅tf.tpu.experimental.embedding.TPUEmbedding 的文档)并使用服务函数创建保存的模型,该函数将执行嵌入查找并将结果传递给您的模型:

model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    batch_size=1024,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

@tf.function(input_signature=[{'feature_one':tf.TensorSpec(...),
                               'feature_two':tf.TensorSpec(...),
                               'feature_three':tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
  embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
      embedding_features, None, embedding.embedding_tables,
      feature_config)
  return model(embedded_features)

model.embedding_api = embedding
tf.saved_model.save(model,
                    export_dir=...,
                    signatures={'serving_default':serve_tensors})

注意:将嵌入 api 对象分配给模型的成员很重要,因为 tf.saved_model.save 仅支持将变量保存为一个 Trackable 对象。由于模型的权重在 model 并且嵌入表由 embedding 管理,我们将 embedding 分配给 model 的属性,以便 tf.saved_model.save 可以找到嵌入变量。

注意:相同的serve_tensors 函数和tf.saved_model.save 调用将直接在训练中起作用。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tpu.experimental.embedding.serving_embedding_lookup。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。