当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.experimental.InMemoryEvaluatorHook用法及代码示例


在没有检查点的情况下在训练中运行评估的挂钩。

继承自:SessionRunHook

用法

tf.estimator.experimental.InMemoryEvaluatorHook(
    estimator, input_fn, steps=None, hooks=None, name=None, every_n_iter=100
)

参数

  • estimator 调用评估的tf.estimator.Estimator 实例。
  • input_fn 相当于input_fn参数为estimator.evaluate.构造用于评估的输入数据的函数。看创建输入函数了解更多信息。该函数应构造并返回以下内容之一:
    • 'tf.data.Dataset' 对象:Dataset 对象的输出必须是具有与以下相同约束的元组(特征、标签)。
    • 元组(特征,标签):其中 featuresTensor 或字符串特征名称字典 TensorlabelsTensor 或字符串标签名称字典 Tensorfeatureslabels 都被 model_fn 消耗。它们应该满足输入对model_fn 的期望。
  • steps 等效于 steps arg 到 estimator.evaluate 。评估模型的步骤数。如果 None ,评估直到 input_fn 引发 end-of-input 异常。
  • hooks 等效于 hooks arg 到 estimator.evaluateSessionRunHook 子类实例列表。用于评估调用中的回调。
  • name 等效于 name arg 到 estimator.evaluate 。如果用户需要对不同的数据集(例如训练数据与测试数据)运行多个评估,则评估的名称。不同评估的指标保存在单独的文件夹中,并单独显示在 tensorboard 中。
  • every_n_iter int ,每 N 次训练迭代运行一次评估器。

抛出

  • ValueError 如果every_n_iter 是非正数或者不是单机训练

例子:

def train_input_fn():
  ...
  return train_dataset

def eval_input_fn():
  ...
  return eval_dataset

estimator = tf.estimator.DNNClassifier(...)

evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
    estimator, eval_input_fn)
estimator.train(train_input_fn, hooks=[evaluator])

这种方法的当前局限性是:

  • 它不支持multi-node 分布式模式。
  • 它不支持变量以外的可保存对象(例如增强树支持)
  • 它不支持自定义保护程序逻辑(例如 ExponentialMovingAverage 支持)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.experimental.InMemoryEvaluatorHook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。