當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.module.BaseModule.forward用法及代碼示例


用法:

forward(data_batch, is_train=None)

參數

  • data_batch(DataBatch) - 可以是任何實現了類似 API 的東西。
  • is_train(bool) - 默認為None, 意思是is_train取值為self.for_training.

前向計算。它支持不同形狀的數據批,例如不同的批大小或不同的圖像大小。如果數據批的重塑與符號或模塊的修改有關,例如更改圖像布局順序或從訓練切換到預測,則需要重新綁定模塊。

例子

>>> import mxnet as mx
>>> from collections import namedtuple
>>> Batch = namedtuple('Batch', ['data'])
>>> data = mx.sym.Variable('data')
>>> out = data * 2
>>> mod = mx.mod.Module(symbol=out, label_names=None)
>>> mod.bind(data_shapes=[('data', (1, 10))])
>>> mod.init_params()
>>> data1 = [mx.nd.ones((1, 10))]
>>> mod.forward(Batch(data1))
>>> print mod.get_outputs()[0].asnumpy()
[[ 2.  2.  2.  2.  2.  2.  2.  2.  2.  2.]]
>>> # Forward with data batch of different shape
>>> data2 = [mx.nd.ones((3, 5))]
>>> mod.forward(Batch(data2))
>>> print mod.get_outputs()[0].asnumpy()
[[ 2.  2.  2.  2.  2.]
 [ 2.  2.  2.  2.  2.]
 [ 2.  2.  2.  2.  2.]]

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule.forward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。