当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.module.BaseModule.predict用法及代码示例


用法:

predict(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)

参数

  • eval_data(DataIter or NDArray or numpy array) - 用于运行预测的评估数据。
  • num_batch(int) - 默认为None, 表示运行数据迭代器中的所有批次。
  • merge_batches(bool) - 默认为True,请参见上面的返回值。
  • reset(bool) - 默认为True, 表示我们是否应该在进行预测之前重置数据迭代器。
  • always_output_list(bool) - 默认为False,请参见上面的返回值。
  • sparse_row_id_fn(A callback function) - 函数需要data_batch作为输入并返回 str -> NDArray 的字典。生成的 dict 用于从 kvstore 中提取 row_sparse 参数,其中 str 键是参数的名称,值是要提取的参数的行 ID。

返回

预测结果。

返回类型

NDArray 列表或 NDArray 列表

运行预测并收集输出。

merge_batchesTrue(默认情况下)时,返回值将是一个列表 [out1, out2, out3] ,其中每个元素都是通过连接所有小批量的输出而形成的。当 always_output_listFalse (默认情况下)时,如果是单个输出,则返回 out1 而不是 [out1]

merge_batchesFalse 时,返回值将是一个嵌套列表,如 [[out1_batch1, out2_batch1], [out1_batch2], ...] 。这种模式很有用,因为在某些情况下(例如分桶),模块不一定会产生相同数量的输出。

结果中的对象具有类型 NDArray 。如果您需要使用 numpy 数组,只需在每个 NDArray 上调用 .asnumpy() 即可。

例子

>>> # An example of using `predict` for prediction.
>>> # Predict on the first 10 batches of val_dataiter
>>> mod.predict(eval_data=val_dataiter, num_batch=10)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule.predict。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。