当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.keras.estimator.model_to_estimator用法及代码示例


从给定的 keras 模型构造一个 Estimator 实例。

用法

tf.compat.v1.keras.estimator.model_to_estimator(
    keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
    config=None, checkpoint_format='saver', metric_names_map=None,
    export_outputs=None
)

参数

  • keras_model 已编译的 Keras 模型对象。此参数与 keras_model_path 互斥。 Estimator 的model_fn 使用模型的结构来克隆模型。默认为 None
  • keras_model_path 以 HDF5 格式保存在磁盘上的已编译 Keras 模型的路径,可以使用 Keras 模型的save() 方法生成。此参数与 keras_model 互斥。默认为 None
  • custom_objects 用于克隆自定义对象的字典。这与不属于此 pip 包的类一起使用。例如,如果用户维护一个继承自 tf.keras.layers.Layerrelu6 类,则传递 custom_objects={'relu6':relu6} 。默认为 None
  • model_dir 用于保存 Estimator 模型参数、图形、TensorBoard 的摘要文件等的目录。如果未设置,将使用 tempfile.mkdtemp 创建目录
  • config RunConfig配置Estimator.允许在model_fn基于配置,例如num_ps_replicas, 或者model_dir.默认为None.如果两者config.model_dirmodel_dir参数(上面)被指定为model_dir 参数优先。
  • checkpoint_format 设置训练时估计器保存的检查点的格式。可能是 savercheckpoint ,具体取决于是否保存来自 tf.train.Savertf.train.Checkpoint 的检查点。此参数当前默认为 saver 。当 2.0 发布时,默认值为 checkpoint 。估计器使用基于名称的 tf.train.Saver 检查点,而 Keras 模型使用来自 tf.train.Checkpoint 的基于对象的检查点。目前,仅函数模型和顺序模型支持从 model_to_estimator 保存基于对象的检查点。默认为'saver'。
  • metric_names_map 可选字典将 Keras 模型输出指标名称映射到自定义名称。这可用于覆盖多 IO 模型用例中的默​​认 Keras 模型输出指标名称,并为 Estimator 中的 eval_metric_ops 提供自定义名称。 Keras 模型度量名称可以使用model.metrics_names 获得,不包括任何损失度量,例如总损失和输出损失。例如,如果您的 Keras 模型有两个输出 out_1out_2 ,具有 mse 损失和 acc 指标,那么 model.metrics_names 将是 ['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc'] 。不包括损失指标的模型指标名称将是 ['out_1_acc', 'out_2_acc']
  • export_outputs 可选字典。这可用于覆盖多 IO 模型用例中的默​​认 Keras 模型输出导出,并为export_outputstf.estimator.EstimatorSpec.默认为无,相当于 {'serving_default':tf.estimator.export.PredictOutput}。如果不是 None,则键必须与model.output_names.一个字典{name:output}其中:
    • 名称:此输出的任意名称。
    • 输出:ExportOutput 类,例如 ClassificationOutput , RegressionOutputPredictOutput 。 Single-headed 模型只需要在这个字典中指定一个条目。 Multi-headed 模型应为每个磁头指定一个条目,其中一个必须使用 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 命名 如果未提供条目,则将创建默认的 PredictOutput 映射到 predictions

返回

  • 来自给定 keras 模型的 Estimator。

抛出

  • ValueError 如果既没有给出 keras_model 也没有给出 keras_model_path。
  • ValueError 如果同时给出 keras_model 和 keras_model_path。
  • ValueError 如果keras_model_path 是 GCS URI。
  • ValueError 如果 keras_model 尚未编译。
  • ValueError 如果给出了无效的checkpoint_format。

如果您使用依赖于 Estimator 的基础架构或其他工具,您仍然可以构建 Keras 模型并使用 model_to_estimator 将 Keras 模型转换为 Estimator 以用于下游系统。

有关使用示例,请参阅:从 Keras 模型创建估算器。

样品重量:

model_to_estimator 返回的估计器被配置为可以处理样本权重(类似于 keras_model.fit(x, y, sample_weights) )。

要在训练或评估 Estimator 时传递样本权重,输入函数返回的第一项应该是带有键 featuressample_weights 的字典。下面的例子:

keras_model = tf.keras.Model(...)
keras_model.compile(...)

estimator = tf.keras.estimator.model_to_estimator(keras_model)

def input_fn():
  return dataset_ops.Dataset.from_tensors(
      ({'features':features, 'sample_weights':sample_weights},
       targets))

estimator.train(input_fn, steps=1)

带有自定义导出签名的示例:

inputs = {'a':tf.keras.Input(..., name='a'),
          'b':tf.keras.Input(..., name='b')}
outputs = {'c':tf.keras.layers.Dense(..., name='c')(inputs['a']),
           'd':tf.keras.layers.Dense(..., name='d')(inputs['b'])}
keras_model = tf.keras.Model(inputs, outputs)
keras_model.compile(...)
export_outputs = {'c':tf.estimator.export.RegressionOutput,
                  'd':tf.estimator.export.ClassificationOutput}

estimator = tf.keras.estimator.model_to_estimator(
    keras_model, export_outputs=export_outputs)

def input_fn():
  return dataset_ops.Dataset.from_tensors(
      ({'features':features, 'sample_weights':sample_weights},
       targets))

estimator.train(input_fn, steps=1)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.keras.estimator.model_to_estimator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。