用于训练和评估 TensorFlow 模型的 Estimator 类。
警告:不建议将估算器用于新代码。估算器运行tf.compat.v1.Session-style 代码更难正确编写,并且可能出现意外行为,尤其是与 TF 2 代码结合使用时。估算器确实属于我们的兼容性保证,但不会收到除安全漏洞以外的任何修复。见迁移指南详情。
继承自:Estimator
用法
tf.estimator.Estimator(
model_fn, model_dir=None, config=None, params=None, warm_start_from=None
)
参数
-
model_fn
模型函数。遵循签名:features
-- 这是从input_fn
返回的第一项,传递给train
,evaluate
和predict
。这应该是一个相同的tf.Tensor
或dict
。labels
-- 这是从传递给train
,evaluate
和predict
的input_fn
返回的第二项。这应该是相同的单个tf.Tensor
或dict
(对于multi-head 型号)。如果 mode 是tf.estimator.ModeKeys.PREDICT
,labels=None
将被传递。如果model_fn
的签名不接受mode
,则model_fn
必须仍然能够处理labels=None
。mode
-- 可选。指定这是训练、评估还是预测。见tf.estimator.ModeKeys
。params
-- 可选的dict
超参数。将接收params
参数中传递给 Estimator 的内容。这允许通过超参数调整来配置 Estimator。config
-- 可选的estimator.RunConfig
对象。将接收作为其config
参数或默认值传递给 Estimator 的内容。允许根据num_ps_replicas
或model_dir
等配置在model_fn
中进行设置。- 返回 --
tf.estimator.EstimatorSpec
-
model_dir
保存模型参数、图形等的目录。这也可用于将检查点从目录加载到估计器中,以继续训练先前保存的模型。如果PathLike
对象,路径将被解析。如果None
,如果设置,将使用config
中的 model_dir。如果两者都设置,则它们必须相同。如果两者都是None
,将使用临时目录。 -
config
estimator.RunConfig
配置对象。 -
params
dict
的超参数将被传递到model_fn
。键是参数的名称,值是基本的 Python 类型。 -
warm_start_from
检查点或 SavedModel 的可选字符串文件路径以进行热启动,或tf.estimator.WarmStartSettings
对象以完全配置热启动。如果没有,只有 TRAINABLE 变量是热启动的。如果提供了字符串文件路径而不是tf.estimator.WarmStartSettings
,则所有变量都是热启动的,并且假定词汇表和tf.Tensor
名称不变。
抛出
-
ValueError
model_fn
的参数与params
不匹配。 -
ValueError
如果这是通过子类调用的,并且该类覆盖了Estimator
的成员。
属性
-
config
-
export_savedmodel
-
model_dir
-
model_fn
返回绑定到self.params
的model_fn
。 -
params
Estimator
对象包装了由 model_fn
指定的模型,该模型在给定输入和许多其他参数的情况下,返回执行训练、评估或预测所需的操作。
所有输出(检查点、事件文件等)都写入 model_dir
或其子目录。如果未设置model_dir
,则使用临时目录。
config
参数可以传递 tf.estimator.RunConfig
对象,其中包含有关执行环境的信息。如果 model_fn
有一个名为 "config" 的参数(并且以相同的方式输入函数),它会被传递给 model_fn
。如果未传递 config
参数,则由 Estimator
实例化。不传递配置意味着使用对本地执行有用的默认值。 Estimator
使配置对模型可用(例如,允许基于可用工人数量的专业化),并且还使用它的一些字段来控制内部,特别是关于检查点。
params
参数包含超参数。如果 model_fn
有一个名为 "params" 的参数,它会被传递给 model_fn
,并以相同的方式传递给输入函数。 Estimator
只传递参数,它不检查它。因此,params
的结构完全取决于开发人员。
Estimator
的任何方法都不能在子类中被覆盖(其构造函数强制执行此操作)。子类应该使用model_fn
来配置基类,并且可以添加实现特殊函数的方法。
有关详细信息,请参阅估算器。
热启动 Estimator
:
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
有关热启动配置的更多详细信息,请参阅tf.estimator.WarmStartSettings
。
相关用法
- Python tf.estimator.TrainSpec用法及代码示例
- Python tf.estimator.LogisticRegressionHead用法及代码示例
- Python tf.estimator.MultiHead用法及代码示例
- Python tf.estimator.PoissonRegressionHead用法及代码示例
- Python tf.estimator.WarmStartSettings用法及代码示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例
- Python tf.estimator.RunConfig用法及代码示例
- Python tf.estimator.MultiLabelHead用法及代码示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代码示例
- Python tf.estimator.BaselineEstimator用法及代码示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代码示例
- Python tf.estimator.experimental.LinearSDCA用法及代码示例
- Python tf.estimator.experimental.RNNClassifier用法及代码示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代码示例
- Python tf.estimator.LinearRegressor用法及代码示例
- Python tf.estimator.LinearEstimator用法及代码示例
- Python tf.estimator.DNNClassifier用法及代码示例
- Python tf.estimator.BaselineClassifier用法及代码示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代码示例
- Python tf.estimator.train_and_evaluate用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.Estimator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。