從給定的 keras 模型構造一個 Estimator
實例。
用法
tf.compat.v1.keras.estimator.model_to_estimator(
keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
config=None, checkpoint_format='saver', metric_names_map=None,
export_outputs=None
)
參數
-
keras_model
已編譯的 Keras 模型對象。此參數與keras_model_path
互斥。 Estimator 的model_fn
使用模型的結構來克隆模型。默認為None
。 -
keras_model_path
以 HDF5 格式保存在磁盤上的已編譯 Keras 模型的路徑,可以使用 Keras 模型的save()
方法生成。此參數與keras_model
互斥。默認為None
。 -
custom_objects
用於克隆自定義對象的字典。這與不屬於此 pip 包的類一起使用。例如,如果用戶維護一個繼承自tf.keras.layers.Layer
的relu6
類,則傳遞custom_objects={'relu6':relu6}
。默認為None
。 -
model_dir
用於保存Estimator
模型參數、圖形、TensorBoard 的摘要文件等的目錄。如果未設置,將使用tempfile.mkdtemp
創建目錄 -
config
RunConfig
配置Estimator
.允許在model_fn
基於配置,例如num_ps_replicas
, 或者model_dir
.默認為None
.如果兩者config.model_dir
和model_dir
參數(上麵)被指定為model_dir
參數優先。 -
checkpoint_format
設置訓練時估計器保存的檢查點的格式。可能是saver
或checkpoint
,具體取決於是否保存來自tf.train.Saver
或tf.train.Checkpoint
的檢查點。此參數當前默認為saver
。當 2.0 發布時,默認值為checkpoint
。估計器使用基於名稱的tf.train.Saver
檢查點,而 Keras 模型使用來自tf.train.Checkpoint
的基於對象的檢查點。目前,僅函數模型和順序模型支持從model_to_estimator
保存基於對象的檢查點。默認為'saver'。 -
metric_names_map
可選字典將 Keras 模型輸出指標名稱映射到自定義名稱。這可用於覆蓋多 IO 模型用例中的默認 Keras 模型輸出指標名稱,並為 Estimator 中的eval_metric_ops
提供自定義名稱。 Keras 模型度量名稱可以使用model.metrics_names
獲得,不包括任何損失度量,例如總損失和輸出損失。例如,如果您的 Keras 模型有兩個輸出out_1
和out_2
,具有mse
損失和acc
指標,那麽model.metrics_names
將是['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']
。不包括損失指標的模型指標名稱將是['out_1_acc', 'out_2_acc']
。 -
export_outputs
可選字典。這可用於覆蓋多 IO 模型用例中的默認 Keras 模型輸出導出,並為export_outputs
在tf.estimator.EstimatorSpec
.默認為無,相當於 {'serving_default':tf.estimator.export.PredictOutput
}。如果不是 None,則鍵必須與model.output_names
.一個字典{name:output}
其中:- 名稱:此輸出的任意名稱。
- 輸出:
ExportOutput
類,例如ClassificationOutput
,RegressionOutput
或PredictOutput
。 Single-headed 模型隻需要在這個字典中指定一個條目。 Multi-headed 模型應為每個磁頭指定一個條目,其中一個必須使用tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
命名 如果未提供條目,則將創建默認的PredictOutput
映射到predictions
。
返回
- 來自給定 keras 模型的 Estimator。
拋出
-
ValueError
如果既沒有給出 keras_model 也沒有給出 keras_model_path。 -
ValueError
如果同時給出 keras_model 和 keras_model_path。 -
ValueError
如果keras_model_path 是 GCS URI。 -
ValueError
如果 keras_model 尚未編譯。 -
ValueError
如果給出了無效的checkpoint_format。
如果您使用依賴於 Estimator 的基礎架構或其他工具,您仍然可以構建 Keras 模型並使用 model_to_estimator 將 Keras 模型轉換為 Estimator 以用於下遊係統。
有關使用示例,請參閱:從 Keras 模型創建估算器。
樣品重量:
model_to_estimator
返回的估計器被配置為可以處理樣本權重(類似於 keras_model.fit(x, y, sample_weights)
)。
要在訓練或評估 Estimator 時傳遞樣本權重,輸入函數返回的第一項應該是帶有鍵 features
和 sample_weights
的字典。下麵的例子:
keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
帶有自定義導出簽名的示例:
inputs = {'a':tf.keras.Input(..., name='a'),
'b':tf.keras.Input(..., name='b')}
outputs = {'c':tf.keras.layers.Dense(..., name='c')(inputs['a']),
'd':tf.keras.layers.Dense(..., name='d')(inputs['b'])}
keras_model = tf.keras.Model(inputs, outputs)
keras_model.compile(...)
export_outputs = {'c':tf.estimator.export.RegressionOutput,
'd':tf.estimator.export.ClassificationOutput}
estimator = tf.keras.estimator.model_to_estimator(
keras_model, export_outputs=export_outputs)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
相關用法
- Python tf.compat.v1.keras.experimental.export_saved_model用法及代碼示例
- Python tf.compat.v1.keras.experimental.load_from_saved_model用法及代碼示例
- Python tf.compat.v1.keras.initializers.Ones.from_config用法及代碼示例
- Python tf.compat.v1.keras.layers.DenseFeatures用法及代碼示例
- Python tf.compat.v1.keras.initializers.Zeros.from_config用法及代碼示例
- Python tf.compat.v1.keras.utils.track_tf1_style_variables用法及代碼示例
- Python tf.compat.v1.keras.initializers.Ones用法及代碼示例
- Python tf.compat.v1.keras.initializers.RandomNormal.from_config用法及代碼示例
- Python tf.compat.v1.keras.initializers.glorot_uniform.from_config用法及代碼示例
- Python tf.compat.v1.keras.initializers.lecun_uniform用法及代碼示例
- Python tf.compat.v1.keras.initializers.he_normal.from_config用法及代碼示例
- Python tf.compat.v1.keras.initializers.Orthogonal.from_config用法及代碼示例
- Python tf.compat.v1.keras.initializers.lecun_normal.from_config用法及代碼示例
- Python tf.compat.v1.keras.initializers.TruncatedNormal用法及代碼示例
- Python tf.compat.v1.keras.initializers.RandomNormal用法及代碼示例
- Python tf.compat.v1.keras.layers.enable_v2_dtype_behavior用法及代碼示例
- Python tf.compat.v1.keras.initializers.he_uniform用法及代碼示例
- Python tf.compat.v1.keras.initializers.Identity.from_config用法及代碼示例
- Python tf.compat.v1.keras.callbacks.TensorBoard用法及代碼示例
- Python tf.compat.v1.keras.initializers.Constant用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.keras.estimator.model_to_estimator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。