當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.module.BaseModule.predict用法及代碼示例


用法:

predict(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)

參數

  • eval_data(DataIter or NDArray or numpy array) - 用於運行預測的評估數據。
  • num_batch(int) - 默認為None, 表示運行數據迭代器中的所有批次。
  • merge_batches(bool) - 默認為True,請參見上麵的返回值。
  • reset(bool) - 默認為True, 表示我們是否應該在進行預測之前重置數據迭代器。
  • always_output_list(bool) - 默認為False,請參見上麵的返回值。
  • sparse_row_id_fn(A callback function) - 函數需要data_batch作為輸入並返回 str -> NDArray 的字典。生成的 dict 用於從 kvstore 中提取 row_sparse 參數,其中 str 鍵是參數的名稱,值是要提取的參數的行 ID。

返回

預測結果。

返回類型

NDArray 列表或 NDArray 列表

運行預測並收集輸出。

merge_batchesTrue(默認情況下)時,返回值將是一個列表 [out1, out2, out3] ,其中每個元素都是通過連接所有小批量的輸出而形成的。當 always_output_listFalse (默認情況下)時,如果是單個輸出,則返回 out1 而不是 [out1]

merge_batchesFalse 時,返回值將是一個嵌套列表,如 [[out1_batch1, out2_batch1], [out1_batch2], ...] 。這種模式很有用,因為在某些情況下(例如分桶),模塊不一定會產生相同數量的輸出。

結果中的對象具有類型 NDArray 。如果您需要使用 numpy 數組,隻需在每個 NDArray 上調用 .asnumpy() 即可。

例子

>>> # An example of using `predict` for prediction.
>>> # Predict on the first 10 batches of val_dataiter
>>> mod.predict(eval_data=val_dataiter, num_batch=10)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule.predict。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。