使用tf.tpu.experimental.embedding
配置應用標準查找操作。
用法
tf.tpu.experimental.embedding.serving_embedding_lookup(
inputs, weights, tables, feature_config
)
參數
-
inputs
張量、稀疏張量或不規則張量的嵌套結構。 -
weights
Tensors、SparseTensors 或 RaggedTensors 或 None 的嵌套結構,無權重。如果不是 None,則結構必須與輸入的結構匹配,但條目允許為 None。 -
tables
將 TableConfig 對象映射到變量的字典。 -
feature_config
FeatureConfig 對象的嵌套結構,具有與輸入相同的結構。
返回
- 與輸入具有相同結構的張量嵌套結構。
此函數是一個實用程序,它允許將tf.tpu.experimental.embedding
配置對象與標準查找函數一起使用。這可以在導出使用 tf.tpu.experimental.embedding.TPUEmbedding
在 CPU 上服務的模型時使用。特別是 tf.tpu.experimental.embedding.TPUEmbedding
僅支持在 TPU 上查找,不應成為服務圖的一部分。
請注意,配置對象中的 TPU 特定選項(例如 max_sequence_length
)將被忽略。
在以下示例中,我們采用經過訓練的模型(有關上下文,請參閱tf.tpu.experimental.embedding.TPUEmbedding
的文檔)並使用服務函數創建保存的模型,該函數將執行嵌入查找並將結果傳遞給您的模型:
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=1024,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
@tf.function(input_signature=[{'feature_one':tf.TensorSpec(...),
'feature_two':tf.TensorSpec(...),
'feature_three':tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
embedding_features, None, embedding.embedding_tables,
feature_config)
return model(embedded_features)
model.embedding_api = embedding
tf.saved_model.save(model,
export_dir=...,
signatures={'serving_default':serve_tensors})
注意:將嵌入 api 對象分配給模型的成員很重要,因為 tf.saved_model.save
僅支持將變量保存為一個 Trackable
對象。由於模型的權重在 model
並且嵌入表由 embedding
管理,我們將 embedding
分配給 model
的屬性,以便 tf.saved_model.save 可以找到嵌入變量。
注意:相同的serve_tensors
函數和tf.saved_model.save
調用將直接在訓練中起作用。
相關用法
- Python tf.tpu.experimental.embedding.FeatureConfig用法及代碼示例
- Python tf.tpu.experimental.embedding.FTRL用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.apply_gradients用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.dequeue用法及代碼示例
- Python tf.tpu.experimental.embedding.SGD用法及代碼示例
- Python tf.tpu.experimental.embedding.TableConfig用法及代碼示例
- Python tf.tpu.experimental.embedding.Adam用法及代碼示例
- Python tf.tpu.experimental.embedding.Adagrad用法及代碼示例
- Python tf.tpu.experimental.embedding.TPUEmbedding.enqueue用法及代碼示例
- Python tf.tpu.experimental.DeviceAssignment用法及代碼示例
- Python tf.types.experimental.GenericFunction.get_concrete_function用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.Checkpoint.restore用法及代碼示例
- Python tf.test.is_built_with_rocm用法及代碼示例
- Python tf.train.Checkpoint.read用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.tpu.experimental.embedding.serving_embedding_lookup。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。