从给定的 keras 模型构造一个 Estimator 实例。
用法
tf.compat.v1.keras.estimator.model_to_estimator(
keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
config=None, checkpoint_format='saver', metric_names_map=None,
export_outputs=None
)参数
-
keras_model已编译的 Keras 模型对象。此参数与keras_model_path互斥。 Estimator 的model_fn使用模型的结构来克隆模型。默认为None。 -
keras_model_path以 HDF5 格式保存在磁盘上的已编译 Keras 模型的路径,可以使用 Keras 模型的save()方法生成。此参数与keras_model互斥。默认为None。 -
custom_objects用于克隆自定义对象的字典。这与不属于此 pip 包的类一起使用。例如,如果用户维护一个继承自tf.keras.layers.Layer的relu6类,则传递custom_objects={'relu6':relu6}。默认为None。 -
model_dir用于保存Estimator模型参数、图形、TensorBoard 的摘要文件等的目录。如果未设置,将使用tempfile.mkdtemp创建目录 -
configRunConfig配置Estimator.允许在model_fn基于配置,例如num_ps_replicas, 或者model_dir.默认为None.如果两者config.model_dir和model_dir参数(上面)被指定为model_dir参数优先。 -
checkpoint_format设置训练时估计器保存的检查点的格式。可能是saver或checkpoint,具体取决于是否保存来自tf.train.Saver或tf.train.Checkpoint的检查点。此参数当前默认为saver。当 2.0 发布时,默认值为checkpoint。估计器使用基于名称的tf.train.Saver检查点,而 Keras 模型使用来自tf.train.Checkpoint的基于对象的检查点。目前,仅函数模型和顺序模型支持从model_to_estimator保存基于对象的检查点。默认为'saver'。 -
metric_names_map可选字典将 Keras 模型输出指标名称映射到自定义名称。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出指标名称,并为 Estimator 中的eval_metric_ops提供自定义名称。 Keras 模型度量名称可以使用model.metrics_names获得,不包括任何损失度量,例如总损失和输出损失。例如,如果您的 Keras 模型有两个输出out_1和out_2,具有mse损失和acc指标,那么model.metrics_names将是['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']。不包括损失指标的模型指标名称将是['out_1_acc', 'out_2_acc']。 -
export_outputs可选字典。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出导出,并为export_outputs在tf.estimator.EstimatorSpec.默认为无,相当于 {'serving_default':tf.estimator.export.PredictOutput}。如果不是 None,则键必须与model.output_names.一个字典{name:output}其中:- 名称:此输出的任意名称。
- 输出:
ExportOutput类,例如ClassificationOutput,RegressionOutput或PredictOutput。 Single-headed 模型只需要在这个字典中指定一个条目。 Multi-headed 模型应为每个磁头指定一个条目,其中一个必须使用tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY命名 如果未提供条目,则将创建默认的PredictOutput映射到predictions。
返回
- 来自给定 keras 模型的 Estimator。
抛出
-
ValueError如果既没有给出 keras_model 也没有给出 keras_model_path。 -
ValueError如果同时给出 keras_model 和 keras_model_path。 -
ValueError如果keras_model_path 是 GCS URI。 -
ValueError如果 keras_model 尚未编译。 -
ValueError如果给出了无效的checkpoint_format。
如果您使用依赖于 Estimator 的基础架构或其他工具,您仍然可以构建 Keras 模型并使用 model_to_estimator 将 Keras 模型转换为 Estimator 以用于下游系统。
有关使用示例,请参阅:从 Keras 模型创建估算器。
样品重量:
model_to_estimator 返回的估计器被配置为可以处理样本权重(类似于 keras_model.fit(x, y, sample_weights) )。
要在训练或评估 Estimator 时传递样本权重,输入函数返回的第一项应该是带有键 features 和 sample_weights 的字典。下面的例子:
keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
带有自定义导出签名的示例:
inputs = {'a':tf.keras.Input(..., name='a'),
'b':tf.keras.Input(..., name='b')}
outputs = {'c':tf.keras.layers.Dense(..., name='c')(inputs['a']),
'd':tf.keras.layers.Dense(..., name='d')(inputs['b'])}
keras_model = tf.keras.Model(inputs, outputs)
keras_model.compile(...)
export_outputs = {'c':tf.estimator.export.RegressionOutput,
'd':tf.estimator.export.ClassificationOutput}
estimator = tf.keras.estimator.model_to_estimator(
keras_model, export_outputs=export_outputs)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
相关用法
- Python tf.compat.v1.keras.experimental.export_saved_model用法及代码示例
- Python tf.compat.v1.keras.experimental.load_from_saved_model用法及代码示例
- Python tf.compat.v1.keras.initializers.Ones.from_config用法及代码示例
- Python tf.compat.v1.keras.layers.DenseFeatures用法及代码示例
- Python tf.compat.v1.keras.initializers.Zeros.from_config用法及代码示例
- Python tf.compat.v1.keras.utils.track_tf1_style_variables用法及代码示例
- Python tf.compat.v1.keras.initializers.Ones用法及代码示例
- Python tf.compat.v1.keras.initializers.RandomNormal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.glorot_uniform.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.lecun_uniform用法及代码示例
- Python tf.compat.v1.keras.initializers.he_normal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.Orthogonal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.lecun_normal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.TruncatedNormal用法及代码示例
- Python tf.compat.v1.keras.initializers.RandomNormal用法及代码示例
- Python tf.compat.v1.keras.layers.enable_v2_dtype_behavior用法及代码示例
- Python tf.compat.v1.keras.initializers.he_uniform用法及代码示例
- Python tf.compat.v1.keras.initializers.Identity.from_config用法及代码示例
- Python tf.compat.v1.keras.callbacks.TensorBoard用法及代码示例
- Python tf.compat.v1.keras.initializers.Constant用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.keras.estimator.model_to_estimator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
