用法:
fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), eval_end_callback=None, eval_batch_end_callback=None, initializer=<mxnet.initializer.Uniform object>, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)
- train_data:(
DataIter
) - 火车DataIter - eval_data:(
DataIter
) - 如果不None
, 将用作验证集,并评估每个 epoch 后的性能。 - eval_metric:(
str
or
EvalMetric
) - 默认为 ‘accuracy’。用于在训练期间显示的性能度量。其他可能的预定义指标是:‘ce’(交叉熵)、‘f1’, ‘mae’、‘mse’, ‘rmse’、‘top_k_accuracy’。 - epoch_end_callback:(
function
or
list of functions
) - 每个回调都将使用当前调用epoch
,symbol
,arg_params
和aux_params
. - batch_end_callback:(
function
or
list of function
) - 每个回调将被调用BatchEndParam
. - kvstore:(
str
or
KVStore
) - 默认为 ‘local’。 - optimizer:(
str
or
Optimizer
) - 默认为 ‘sgd’。 - optimizer_params:(
dict
) - 默认为(('learning_rate', 0.01),)
.优化器构造函数的参数。默认值不是字典,只是为了避免对危险的默认值发出 pylint 警告。 - eval_end_callback:(
function
or
list of function
) - 这些将在每次完整评估结束时调用,指标涵盖整个评估集。 - eval_batch_end_callback:(
function
or
list of function
) - 这些将在评估期间每个小批量结束时调用。 - initializer:(
Initializer
) - 调用初始化程序以在模块参数尚未初始化时对其进行初始化。 - arg_params:(
dict
) - 默认为None
, 如果不None
, 应该是来自已训练模型的现有参数或从检查点(以前保存的模型)加载的参数。在这种情况下,这里的值将用于初始化模块参数,除非它们已经由用户通过调用init_params
或者fit
.arg_params
优先级高于initializer
. - aux_params:(
dict
) - 默认为None
.相似arg_params
, 辅助状态除外。 - allow_missing:(
bool
) - 默认为False
.指示是否允许丢失参数arg_params
和aux_params
不是None
.如果这是True
,那么缺失的参数将通过initializer
. - force_rebind:(
bool
) - 默认为False
.如果已经绑定,是否强制重新绑定执行者。 - force_init:(
bool
) - 默认为False
.指示是否强制初始化,即使参数已经初始化。 - begin_epoch:(
int
) - 默认为 0。表示起始纪元。通常,如果从在第 N 轮的先前训练阶段保存的检查点恢复,则该值应为 N+1。 - num_epoch:(
int
) - 训练的 epoch 数。 - sparse_row_id_fn:(
A callback function
) - 函数需要data_batch
作为输入并返回 str -> NDArray 的字典。生成的 dict 用于从 kvstore 中提取 row_sparse 参数,其中 str 键是参数的名称,值是要提取的参数的行 ID。
- train_data:(
参数:
训练模块参数。
查看Module Tutorial 以查看端到端用例。
例子:
>>> # An example of using fit for training. >>> # Assume training dataIter and validation dataIter are ready >>> # Assume loading a previously checkpointed model >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3) >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd', ... optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, ... arg_params=arg_params, aux_params=aux_params, ... eval_metric='acc', num_epoch=10, begin_epoch=3)
相关用法
- Python mxnet.module.BaseModule.forward用法及代码示例
- Python mxnet.module.BaseModule.get_outputs用法及代码示例
- Python mxnet.module.BaseModule.bind用法及代码示例
- Python mxnet.module.BaseModule.init_params用法及代码示例
- Python mxnet.module.BaseModule.get_params用法及代码示例
- Python mxnet.module.BaseModule.set_params用法及代码示例
- Python mxnet.module.BaseModule.update用法及代码示例
- Python mxnet.module.BaseModule.iter_predict用法及代码示例
- Python mxnet.module.BaseModule.save_params用法及代码示例
- Python mxnet.module.BaseModule.init_optimizer用法及代码示例
- Python mxnet.module.BaseModule.score用法及代码示例
- Python mxnet.module.BaseModule.update_metric用法及代码示例
- Python mxnet.module.BaseModule.predict用法及代码示例
- Python mxnet.module.BaseModule.get_input_grads用法及代码示例
- Python mxnet.module.BaseModule.backward用法及代码示例
- Python mxnet.module.BaseModule.load_params用法及代码示例
- Python mxnet.module.BaseModule用法及代码示例
- Python mxnet.module.BucketingModule.set_params用法及代码示例
- Python mxnet.module.SequentialModule.add用法及代码示例
- Python mxnet.module.Module.set_params用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule.fit。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。