当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.module.BaseModule.fit用法及代码示例


用法:

fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), eval_end_callback=None, eval_batch_end_callback=None, initializer=<mxnet.initializer.Uniform object>, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)

参数

  • train_data(DataIter) - 火车DataIter
  • eval_data(DataIter) - 如果不None, 将用作验证集,并评估每个 epoch 后的性能。
  • eval_metric(str or EvalMetric) - 默认为 ‘accuracy’。用于在训练期间显示的性能度量。其他可能的预定义指标是:‘ce’(交叉熵)、‘f1’, ‘mae’、‘mse’, ‘rmse’、‘top_k_accuracy’。
  • epoch_end_callback(function or list of functions) - 每个回调都将使用当前调用epoch,symbol,arg_paramsaux_params.
  • batch_end_callback(function or list of function) - 每个回调将被调用BatchEndParam.
  • kvstore(str or KVStore) - 默认为 ‘local’。
  • optimizer(str or Optimizer) - 默认为 ‘sgd’。
  • optimizer_params(dict) - 默认为(('learning_rate', 0.01),).优化器构造函数的参数。默认值不是字典,只是为了避免对危险的默认值发出 pylint 警告。
  • eval_end_callback(function or list of function) - 这些将在每次完整评估结束时调用,指标涵盖整个评估集。
  • eval_batch_end_callback(function or list of function) - 这些将在评估期间每个小批量结束时调用。
  • initializer(Initializer) - 调用初始化程序以在模块参数尚未初始化时对其进行初始化。
  • arg_params(dict) - 默认为None, 如果不None, 应该是来自已训练模型的现有参数或从检查点(以前保存的模型)加载的参数。在这种情况下,这里的值将用于初始化模块参数,除非它们已经由用户通过调用init_params或者fit.arg_params优先级高于initializer.
  • aux_params(dict) - 默认为None.相似arg_params, 辅助状态除外。
  • allow_missing(bool) - 默认为False.指示是否允许丢失参数arg_paramsaux_params不是None.如果这是True,那么缺失的参数将通过initializer.
  • force_rebind(bool) - 默认为False.如果已经绑定,是否强制重新绑定执行者。
  • force_init(bool) - 默认为False.指示是否强制初始化,即使参数已经初始化。
  • begin_epoch(int) - 默认为 0。表示起始纪元。通常,如果从在第 N 轮的先前训练阶段保存的检查点恢复,则该值应为 N+1。
  • num_epoch(int) - 训练的 epoch 数。
  • sparse_row_id_fn(A callback function) - 函数需要data_batch作为输入并返回 str -> NDArray 的字典。生成的 dict 用于从 kvstore 中提取 row_sparse 参数,其中 str 键是参数的名称,值是要提取的参数的行 ID。

训练模块参数。

查看Module Tutorial 以查看端到端用例。

例子

>>> # An example of using fit for training.
>>> # Assume training dataIter and validation dataIter are ready
>>> # Assume loading a previously checkpointed model
>>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
>>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
...     arg_params=arg_params, aux_params=aux_params,
...     eval_metric='acc', num_epoch=10, begin_epoch=3)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule.fit。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。