从给定的 keras 模型构造一个 Estimator
实例。
用法
tf.compat.v1.keras.estimator.model_to_estimator(
keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
config=None, checkpoint_format='saver', metric_names_map=None,
export_outputs=None
)
参数
-
keras_model
已编译的 Keras 模型对象。此参数与keras_model_path
互斥。 Estimator 的model_fn
使用模型的结构来克隆模型。默认为None
。 -
keras_model_path
以 HDF5 格式保存在磁盘上的已编译 Keras 模型的路径,可以使用 Keras 模型的save()
方法生成。此参数与keras_model
互斥。默认为None
。 -
custom_objects
用于克隆自定义对象的字典。这与不属于此 pip 包的类一起使用。例如,如果用户维护一个继承自tf.keras.layers.Layer
的relu6
类,则传递custom_objects={'relu6':relu6}
。默认为None
。 -
model_dir
用于保存Estimator
模型参数、图形、TensorBoard 的摘要文件等的目录。如果未设置,将使用tempfile.mkdtemp
创建目录 -
config
RunConfig
配置Estimator
.允许在model_fn
基于配置,例如num_ps_replicas
, 或者model_dir
.默认为None
.如果两者config.model_dir
和model_dir
参数(上面)被指定为model_dir
参数优先。 -
checkpoint_format
设置训练时估计器保存的检查点的格式。可能是saver
或checkpoint
,具体取决于是否保存来自tf.train.Saver
或tf.train.Checkpoint
的检查点。此参数当前默认为saver
。当 2.0 发布时,默认值为checkpoint
。估计器使用基于名称的tf.train.Saver
检查点,而 Keras 模型使用来自tf.train.Checkpoint
的基于对象的检查点。目前,仅函数模型和顺序模型支持从model_to_estimator
保存基于对象的检查点。默认为'saver'。 -
metric_names_map
可选字典将 Keras 模型输出指标名称映射到自定义名称。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出指标名称,并为 Estimator 中的eval_metric_ops
提供自定义名称。 Keras 模型度量名称可以使用model.metrics_names
获得,不包括任何损失度量,例如总损失和输出损失。例如,如果您的 Keras 模型有两个输出out_1
和out_2
,具有mse
损失和acc
指标,那么model.metrics_names
将是['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']
。不包括损失指标的模型指标名称将是['out_1_acc', 'out_2_acc']
。 -
export_outputs
可选字典。这可用于覆盖多 IO 模型用例中的默认 Keras 模型输出导出,并为export_outputs
在tf.estimator.EstimatorSpec
.默认为无,相当于 {'serving_default':tf.estimator.export.PredictOutput
}。如果不是 None,则键必须与model.output_names
.一个字典{name:output}
其中:- 名称:此输出的任意名称。
- 输出:
ExportOutput
类,例如ClassificationOutput
,RegressionOutput
或PredictOutput
。 Single-headed 模型只需要在这个字典中指定一个条目。 Multi-headed 模型应为每个磁头指定一个条目,其中一个必须使用tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
命名 如果未提供条目,则将创建默认的PredictOutput
映射到predictions
。
返回
- 来自给定 keras 模型的 Estimator。
抛出
-
ValueError
如果既没有给出 keras_model 也没有给出 keras_model_path。 -
ValueError
如果同时给出 keras_model 和 keras_model_path。 -
ValueError
如果keras_model_path 是 GCS URI。 -
ValueError
如果 keras_model 尚未编译。 -
ValueError
如果给出了无效的checkpoint_format。
如果您使用依赖于 Estimator 的基础架构或其他工具,您仍然可以构建 Keras 模型并使用 model_to_estimator 将 Keras 模型转换为 Estimator 以用于下游系统。
有关使用示例,请参阅:从 Keras 模型创建估算器。
样品重量:
model_to_estimator
返回的估计器被配置为可以处理样本权重(类似于 keras_model.fit(x, y, sample_weights)
)。
要在训练或评估 Estimator 时传递样本权重,输入函数返回的第一项应该是带有键 features
和 sample_weights
的字典。下面的例子:
keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
带有自定义导出签名的示例:
inputs = {'a':tf.keras.Input(..., name='a'),
'b':tf.keras.Input(..., name='b')}
outputs = {'c':tf.keras.layers.Dense(..., name='c')(inputs['a']),
'd':tf.keras.layers.Dense(..., name='d')(inputs['b'])}
keras_model = tf.keras.Model(inputs, outputs)
keras_model.compile(...)
export_outputs = {'c':tf.estimator.export.RegressionOutput,
'd':tf.estimator.export.ClassificationOutput}
estimator = tf.keras.estimator.model_to_estimator(
keras_model, export_outputs=export_outputs)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features':features, 'sample_weights':sample_weights},
targets))
estimator.train(input_fn, steps=1)
相关用法
- Python tf.compat.v1.keras.experimental.export_saved_model用法及代码示例
- Python tf.compat.v1.keras.experimental.load_from_saved_model用法及代码示例
- Python tf.compat.v1.keras.initializers.Ones.from_config用法及代码示例
- Python tf.compat.v1.keras.layers.DenseFeatures用法及代码示例
- Python tf.compat.v1.keras.initializers.Zeros.from_config用法及代码示例
- Python tf.compat.v1.keras.utils.track_tf1_style_variables用法及代码示例
- Python tf.compat.v1.keras.initializers.Ones用法及代码示例
- Python tf.compat.v1.keras.initializers.RandomNormal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.glorot_uniform.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.lecun_uniform用法及代码示例
- Python tf.compat.v1.keras.initializers.he_normal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.Orthogonal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.lecun_normal.from_config用法及代码示例
- Python tf.compat.v1.keras.initializers.TruncatedNormal用法及代码示例
- Python tf.compat.v1.keras.initializers.RandomNormal用法及代码示例
- Python tf.compat.v1.keras.layers.enable_v2_dtype_behavior用法及代码示例
- Python tf.compat.v1.keras.initializers.he_uniform用法及代码示例
- Python tf.compat.v1.keras.initializers.Identity.from_config用法及代码示例
- Python tf.compat.v1.keras.callbacks.TensorBoard用法及代码示例
- Python tf.compat.v1.keras.initializers.Constant用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.keras.estimator.model_to_estimator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。