用法:
fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), eval_end_callback=None, eval_batch_end_callback=None, initializer=<mxnet.initializer.Uniform object>, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)
- train_data:(
DataIter
) - 火車DataIter - eval_data:(
DataIter
) - 如果不None
, 將用作驗證集,並評估每個 epoch 後的性能。 - eval_metric:(
str
or
EvalMetric
) - 默認為 ‘accuracy’。用於在訓練期間顯示的性能度量。其他可能的預定義指標是:‘ce’(交叉熵)、‘f1’, ‘mae’、‘mse’, ‘rmse’、‘top_k_accuracy’。 - epoch_end_callback:(
function
or
list of functions
) - 每個回調都將使用當前調用epoch
,symbol
,arg_params
和aux_params
. - batch_end_callback:(
function
or
list of function
) - 每個回調將被調用BatchEndParam
. - kvstore:(
str
or
KVStore
) - 默認為 ‘local’。 - optimizer:(
str
or
Optimizer
) - 默認為 ‘sgd’。 - optimizer_params:(
dict
) - 默認為(('learning_rate', 0.01),)
.優化器構造函數的參數。默認值不是字典,隻是為了避免對危險的默認值發出 pylint 警告。 - eval_end_callback:(
function
or
list of function
) - 這些將在每次完整評估結束時調用,指標涵蓋整個評估集。 - eval_batch_end_callback:(
function
or
list of function
) - 這些將在評估期間每個小批量結束時調用。 - initializer:(
Initializer
) - 調用初始化程序以在模塊參數尚未初始化時對其進行初始化。 - arg_params:(
dict
) - 默認為None
, 如果不None
, 應該是來自已訓練模型的現有參數或從檢查點(以前保存的模型)加載的參數。在這種情況下,這裏的值將用於初始化模塊參數,除非它們已經由用戶通過調用init_params
或者fit
.arg_params
優先級高於initializer
. - aux_params:(
dict
) - 默認為None
.相似arg_params
, 輔助狀態除外。 - allow_missing:(
bool
) - 默認為False
.指示是否允許丟失參數arg_params
和aux_params
不是None
.如果這是True
,那麽缺失的參數將通過initializer
. - force_rebind:(
bool
) - 默認為False
.如果已經綁定,是否強製重新綁定執行者。 - force_init:(
bool
) - 默認為False
.指示是否強製初始化,即使參數已經初始化。 - begin_epoch:(
int
) - 默認為 0。表示起始紀元。通常,如果從在第 N 輪的先前訓練階段保存的檢查點恢複,則該值應為 N+1。 - num_epoch:(
int
) - 訓練的 epoch 數。 - sparse_row_id_fn:(
A callback function
) - 函數需要data_batch
作為輸入並返回 str -> NDArray 的字典。生成的 dict 用於從 kvstore 中提取 row_sparse 參數,其中 str 鍵是參數的名稱,值是要提取的參數的行 ID。
- train_data:(
參數:
訓練模塊參數。
查看Module Tutorial 以查看端到端用例。
例子:
>>> # An example of using fit for training. >>> # Assume training dataIter and validation dataIter are ready >>> # Assume loading a previously checkpointed model >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3) >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd', ... optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, ... arg_params=arg_params, aux_params=aux_params, ... eval_metric='acc', num_epoch=10, begin_epoch=3)
相關用法
- Python mxnet.module.BaseModule.forward用法及代碼示例
- Python mxnet.module.BaseModule.get_outputs用法及代碼示例
- Python mxnet.module.BaseModule.bind用法及代碼示例
- Python mxnet.module.BaseModule.init_params用法及代碼示例
- Python mxnet.module.BaseModule.get_params用法及代碼示例
- Python mxnet.module.BaseModule.set_params用法及代碼示例
- Python mxnet.module.BaseModule.update用法及代碼示例
- Python mxnet.module.BaseModule.iter_predict用法及代碼示例
- Python mxnet.module.BaseModule.save_params用法及代碼示例
- Python mxnet.module.BaseModule.init_optimizer用法及代碼示例
- Python mxnet.module.BaseModule.score用法及代碼示例
- Python mxnet.module.BaseModule.update_metric用法及代碼示例
- Python mxnet.module.BaseModule.predict用法及代碼示例
- Python mxnet.module.BaseModule.get_input_grads用法及代碼示例
- Python mxnet.module.BaseModule.backward用法及代碼示例
- Python mxnet.module.BaseModule.load_params用法及代碼示例
- Python mxnet.module.BaseModule用法及代碼示例
- Python mxnet.module.BucketingModule.set_params用法及代碼示例
- Python mxnet.module.SequentialModule.add用法及代碼示例
- Python mxnet.module.Module.set_params用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule.fit。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。