当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.TrainSpec用法及代码示例


train_and_evaluate 调用的"train" 部分的配置。

用法

tf.estimator.TrainSpec(
    input_fn, max_steps=None, hooks=None, saving_listeners=None
)

参数

  • input_fn 提供输入数据以作为小批量进行训练的函数。看预制估算器了解更多信息。该函数应构造并返回以下内容之一:
    • 'tf.data.Dataset' 对象:Dataset 对象的输出必须是具有与以下相同约束的元组(特征、标签)。
    • 元组(特征、标签):其中 features 是 Tensor 或字符串特征名称字典到 Tensor 并且标签是 Tensor 或字符串标签名称字典到 Tensor
  • max_steps Int. 训练模型的总步数为正数。如果 None ,永远训练。训练 input_fn 预计不会生成 OutOfRangeErrorStopIteration 异常。有关详细信息,请参阅train_and_evaluate 停止条件部分。
  • hooks tf.train.SessionRunHook 对象的可迭代对象,可在训练期间在所有工作人员(包括主管)上运行。
  • saving_listeners tf.estimator.CheckpointSaverListener 对象的迭代,在训练期间主要运行。

抛出

  • ValueError 如果任何输入参数无效。
  • TypeError 如果任何参数不是预期的类型。

属性

  • input_fn 字段编号 0 的 namedtuple 别名
  • max_steps 字段编号 1 的 namedtuple 别名
  • hooks 字段编号 2 的 namedtuple 别名
  • saving_listeners 字段编号 3 的 namedtuple 别名

TrainSpec 确定训练的输入数据以及持续时间。可选的钩子在训练的不同阶段运行。

用法:

train_spec = tf.estimator.TrainSpec(
   input_fn=lambda:1,
   max_steps=100,
   hooks=[_StopAtSecsHook(stop_after_secs=10)],
   saving_listeners=[_NewCheckpointListenerForEvaluate(None, 20, None)])
train_spec.saving_listeners[0]._eval_throttle_secs
20
train_spec.hooks[0]._stop_after_secs
10
train_spec.max_steps
100

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.TrainSpec。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。