當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.module.BaseModule.fit用法及代碼示例


用法:

fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), eval_end_callback=None, eval_batch_end_callback=None, initializer=<mxnet.initializer.Uniform object>, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)

參數

  • train_data(DataIter) - 火車DataIter
  • eval_data(DataIter) - 如果不None, 將用作驗證集,並評估每個 epoch 後的性能。
  • eval_metric(str or EvalMetric) - 默認為 ‘accuracy’。用於在訓練期間顯示的性能度量。其他可能的預定義指標是:‘ce’(交叉熵)、‘f1’, ‘mae’、‘mse’, ‘rmse’、‘top_k_accuracy’。
  • epoch_end_callback(function or list of functions) - 每個回調都將使用當前調用epoch,symbol,arg_paramsaux_params.
  • batch_end_callback(function or list of function) - 每個回調將被調用BatchEndParam.
  • kvstore(str or KVStore) - 默認為 ‘local’。
  • optimizer(str or Optimizer) - 默認為 ‘sgd’。
  • optimizer_params(dict) - 默認為(('learning_rate', 0.01),).優化器構造函數的參數。默認值不是字典,隻是為了避免對危險的默認值發出 pylint 警告。
  • eval_end_callback(function or list of function) - 這些將在每次完整評估結束時調用,指標涵蓋整個評估集。
  • eval_batch_end_callback(function or list of function) - 這些將在評估期間每個小批量結束時調用。
  • initializer(Initializer) - 調用初始化程序以在模塊參數尚未初始化時對其進行初始化。
  • arg_params(dict) - 默認為None, 如果不None, 應該是來自已訓練模型的現有參數或從檢查點(以前保存的模型)加載的參數。在這種情況下,這裏的值將用於初始化模塊參數,除非它們已經由用戶通過調用init_params或者fit.arg_params優先級高於initializer.
  • aux_params(dict) - 默認為None.相似arg_params, 輔助狀態除外。
  • allow_missing(bool) - 默認為False.指示是否允許丟失參數arg_paramsaux_params不是None.如果這是True,那麽缺失的參數將通過initializer.
  • force_rebind(bool) - 默認為False.如果已經綁定,是否強製重新綁定執行者。
  • force_init(bool) - 默認為False.指示是否強製初始化,即使參數已經初始化。
  • begin_epoch(int) - 默認為 0。表示起始紀元。通常,如果從在第 N 輪的先前訓練階段保存的檢查點恢複,則該值應為 N+1。
  • num_epoch(int) - 訓練的 epoch 數。
  • sparse_row_id_fn(A callback function) - 函數需要data_batch作為輸入並返回 str -> NDArray 的字典。生成的 dict 用於從 kvstore 中提取 row_sparse 參數,其中 str 鍵是參數的名稱,值是要提取的參數的行 ID。

訓練模塊參數。

查看Module Tutorial 以查看端到端用例。

例子

>>> # An example of using fit for training.
>>> # Assume training dataIter and validation dataIter are ready
>>> # Assume loading a previously checkpointed model
>>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
>>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
...     arg_params=arg_params, aux_params=aux_params,
...     eval_metric='acc', num_epoch=10, begin_epoch=3)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule.fit。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。