当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.module.BaseModule.forward用法及代码示例


用法:

forward(data_batch, is_train=None)

参数

  • data_batch(DataBatch) - 可以是任何实现了类似 API 的东西。
  • is_train(bool) - 默认为None, 意思是is_train取值为self.for_training.

前向计算。它支持不同形状的数据批,例如不同的批大小或不同的图像大小。如果数据批的重塑与符号或模块的修改有关,例如更改图像布局顺序或从训练切换到预测,则需要重新绑定模块。

例子

>>> import mxnet as mx
>>> from collections import namedtuple
>>> Batch = namedtuple('Batch', ['data'])
>>> data = mx.sym.Variable('data')
>>> out = data * 2
>>> mod = mx.mod.Module(symbol=out, label_names=None)
>>> mod.bind(data_shapes=[('data', (1, 10))])
>>> mod.init_params()
>>> data1 = [mx.nd.ones((1, 10))]
>>> mod.forward(Batch(data1))
>>> print mod.get_outputs()[0].asnumpy()
[[ 2.  2.  2.  2.  2.  2.  2.  2.  2.  2.]]
>>> # Forward with data batch of different shape
>>> data2 = [mx.nd.ones((3, 5))]
>>> mod.forward(Batch(data2))
>>> print mod.get_outputs()[0].asnumpy()
[[ 2.  2.  2.  2.  2.]
 [ 2.  2.  2.  2.  2.]
 [ 2.  2.  2.  2.  2.]]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule.forward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。