从给定的 keras 模型构造一个 Estimator
实例。
用法
tf.keras.estimator.model_to_estimator(
keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
config=None, checkpoint_format='checkpoint', metric_names_map=None,
export_outputs=None
)
参数
-
keras_model
已编译的 Keras 模型对象。此参数与keras_model_path
互斥。 Estimator 的model_fn
使用模型的结构来克隆模型。默认为None
。 -
keras_model_path
以 HDF5 格式保存在磁盘上的已编译 Keras 模型的路径,可以使用 Keras 模型的save()
方法生成。此参数与keras_model
互斥。默认为None
。 -
custom_objects
用于克隆自定义对象的字典。这与不属于此 pip 包的类一起使用。例如,如果用户维护一个继承自tf.keras.layers.Layer
的relu6
类,则传递custom_objects={'relu6':relu6}
。默认为None
。 -
model_dir
用于保存Estimator
模型参数、图形、TensorBoard 的摘要文件等的目录。如果未设置,将使用tempfile.mkdtemp
创建目录 -
config
RunConfig
配置Estimator
.允许在model_fn
基于配置,例如num_ps_replicas
, 或者model_dir
.默认为None
.如果两者config.model_dir
和model_dir
参数(上面)被指定为model_dir
参数优先。 -
checkpoint_format
设置训练时估计器保存的检查点的格式。可能是saver
或checkpoint
,具体取决于是否保存来自tf.compat.v1.train.Saver
或tf.train.Checkpoint
的检查点。默认值为checkpoint
。估计器使用基于名称的tf.train.Saver
检查点,而 Keras 模型使用来自tf.train.Checkpoint
的基于对象的检查点。目前,仅函数模型和顺序模型支持从model_to_estimator
保存基于对象的检查点。默认为'checkpoint'。 -
metric_names_map
可选字典将 Keras 模型输出指标名称映射到自定义名称。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出指标名称,并为 Estimator 中的eval_metric_ops
提供自定义名称。 Keras 模型度量名称可以使用model.metrics_names
获得,不包括任何损失度量,例如总损失和输出损失。例如,如果您的 Keras 模型有两个输出out_1
和out_2
,具有mse
损失和acc
指标,那么model.metrics_names
将是['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']
。不包括损失指标的模型指标名称将是['out_1_acc', 'out_2_acc']
。 -
export_outputs
可选字典。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出导出,并为export_outputs
在tf.estimator.EstimatorSpec
.默认为无,相当于 {'serving_default':tf.estimator.export.PredictOutput
}。如果不是 None,则键必须与model.output_names
.一个字典{name:output}
其中:- 名称:此输出的任意名称。
- 输出:
ExportOutput
类,例如ClassificationOutput
,RegressionOutput
或PredictOutput
。 Single-headed 模型只需要在这个字典中指定一个条目。 Multi-headed 模型应为每个磁头指定一个条目,其中一个必须使用tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
命名 如果未提供条目,则将创建默认的PredictOutput
映射到predictions
。
返回
- 来自给定 keras 模型的 Estimator。
抛出
-
ValueError
如果既没有给出 keras_model 也没有给出 keras_model_path。 -
ValueError
如果同时给出 keras_model 和 keras_model_path。 -
ValueError
如果keras_model_path 是 GCS URI。 -
ValueError
如果 keras_model 尚未编译。 -
ValueError
如果给出了无效的checkpoint_format。
如果您使用依赖于 Estimator 的基础架构或其他工具,您仍然可以构建 Keras 模型并使用 model_to_estimator 将 Keras 模型转换为 Estimator 以用于下游系统。
有关使用示例,请参阅:从 Keras 模型创建估算器。
样品重量:
model_to_estimator
返回的估计器被配置为可以处理样本权重(类似于 keras_model.fit(x, y, sample_weights)
)。
要在训练或评估 Estimator 时传递样本权重,输入函数返回的第一项应该是带有键 features
和 sample_weights
的字典。下面的例子:
keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
带有自定义导出签名的示例:
inputs = {'a':tf.keras.Input(..., name='a'),
'b':tf.keras.Input(..., name='b')}
outputs = {'c':tf.keras.layers.Dense(..., name='c')(inputs['a']),
'd':tf.keras.layers.Dense(..., name='d')(inputs['b'])}
keras_model = tf.keras.Model(inputs, outputs)
keras_model.compile(...)
export_outputs = {'c':tf.estimator.export.RegressionOutput,
'd':tf.estimator.export.ClassificationOutput}
estimator = tf.keras.estimator.model_to_estimator(
keras_model, export_outputs=export_outputs)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
注意:我们不支持在 Keras 中创建加权指标并使用 model_to_estimator
在 Estimator API 中将它们转换为加权指标。您必须使用 add_metrics
函数直接在估算器规范上创建这些指标。
要自定义估计器eval_metric_ops
名称,您可以传入 metric_names_map
字典,将 keras 模型输出指标名称映射到自定义名称,如下所示:
input_a = tf.keras.layers.Input(shape=(16,), name='input_a')
input_b = tf.keras.layers.Input(shape=(16,), name='input_b')
dense = tf.keras.layers.Dense(8, name='dense_1')
interm_a = dense(input_a)
interm_b = dense(input_b)
merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')(
merged)
output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')(
merged)
keras_model = tf.keras.models.Model(
inputs=[input_a, input_b], outputs=[output_a, output_b])
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
metrics={
'dense_2':'categorical_accuracy',
'dense_3':'categorical_accuracy'
})
metric_names_map = {
'dense_2_categorical_accuracy':'acc_1',
'dense_3_categorical_accuracy':'acc_2',
}
keras_est = tf.keras.estimator.model_to_estimator(
keras_model=keras_model,
config=config,
metric_names_map=metric_names_map)
相关用法
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.experimental.SequenceFeatures用法及代码示例
- Python tf.keras.experimental.LinearModel.save用法及代码示例
- Python tf.keras.experimental.LinearModel.compile用法及代码示例
- Python tf.keras.experimental.LinearModel.save_spec用法及代码示例
- Python tf.keras.experimental.WideDeepModel用法及代码示例
- Python tf.keras.experimental.PeepholeLSTMCell用法及代码示例
- Python tf.keras.experimental.LinearModel.compute_loss用法及代码示例
- Python tf.keras.experimental.LinearModel用法及代码示例
- Python tf.keras.experimental.WideDeepModel.reset_metrics用法及代码示例
- Python tf.keras.experimental.WideDeepModel.save_spec用法及代码示例
- Python tf.keras.experimental.LinearModel.reset_metrics用法及代码示例
- Python tf.keras.experimental.WideDeepModel.save用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_metrics用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compile用法及代码示例
- Python tf.keras.experimental.LinearModel.compute_metrics用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.estimator.model_to_estimator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。