當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python mxnet.symbol.op.FullyConnected用法及代碼示例

用法:

mxnet.symbol.op.FullyConnected(data=None, weight=None, bias=None, num_hidden=_Null, no_bias=_Null, flatten=_Null, name=None, attr=None, out=None, **kwargs)

參數

  • data(Symbol) - 輸入數據。
  • weight(Symbol) - 權重矩陣。
  • bias(Symbol) - 偏置參數。
  • num_hidden(int, required) - 輸出的隱藏節點數。
  • no_bias(boolean, optional, default=0) - 是否禁用偏差參數。
  • flatten(boolean, optional, default=1) - 是否折疊輸入數據張量的第一個軸以外的所有軸。
  • name(string, optional.) - 結果符號的名稱。

返回

結果符號。

返回類型

Symbol

應用線性變換:

如果 flatten 設置為 true,則形狀為:

  • data(batch_size, x1, x2, …, xn)
  • weight(num_hidden, x1 * x2 * … * xn)
  • bias(num_hidden,)
  • out(batch_size, num_hidden)

如果 flatten 設置為 false,則形狀為:

  • data(x1, x2, …, xn, input_dim)
  • weight(num_hidden, input_dim)
  • bias(num_hidden,)
  • out(x1, x2, …, xn, num_hidden)

可學習的參數包括 weightbias

如果 no_bias 設置為 true,則忽略 bias 項。

注意

FullyConnected 的稀疏支持僅限於使用 row_sparse 權重和偏差進行前向評估,其中 weight.indicesbias.indices 的長度必須等於 num_hidden 。這對於使用重要性采樣或噪聲對比估計訓練的row_sparse 權重的模型推斷很有用。

要使用‘csr’稀疏數據計算線性變換,建議使用sparse.dot而不是sparse.FullyConnected。

例子

構造一個目標維度為 512 的全連接算子。

>>> data = Variable('data')  # or some constructed NN
>>> op = FullyConnected(data=data,
... num_hidden=512,
... name='FC1')
>>> op
<Symbol FC1>
>>> SymbolDoc.get_output_shape(op, data=(128, 100))
{'FC1_output': (128L, 512L)}

帶有ReLU 激活的簡單 3 層 MLP:

>>> net = Variable('data')
>>> for i, dim in enumerate([128, 64]):
... net = FullyConnected(data=net, num_hidden=dim, name='FC%d' % i)
... net = Activation(data=net, act_type='relu', name='ReLU%d' % i)
>>> # 10-class predictor (e.g. MNIST)
>>> net = FullyConnected(data=net, num_hidden=10, name='pred')
>>> net
<Symbol pred>
>>> dim_in, dim_out = (3, 4)
>>> x, w, b = test_utils.random_arrays((10, dim_in), (dim_out, dim_in), (dim_out,))
>>> op = FullyConnected(num_hidden=dim_out, name='FC')
>>> out = test_utils.simple_forward(op, FC_data=x, FC_weight=w, FC_bias=b)
>>> # numpy implementation of FullyConnected
>>> out_np = np.dot(x, w.T) + b
>>> test_utils.almost_equal(out, out_np)
True

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.symbol.op.FullyConnected。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。