当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.op.FullyConnected用法及代码示例


用法:

mxnet.symbol.op.FullyConnected(data=None, weight=None, bias=None, num_hidden=_Null, no_bias=_Null, flatten=_Null, name=None, attr=None, out=None, **kwargs)

参数

  • data(Symbol) - 输入数据。
  • weight(Symbol) - 权重矩阵。
  • bias(Symbol) - 偏置参数。
  • num_hidden(int, required) - 输出的隐藏节点数。
  • no_bias(boolean, optional, default=0) - 是否禁用偏差参数。
  • flatten(boolean, optional, default=1) - 是否折叠输入数据张量的第一个轴以外的所有轴。
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

应用线性变换:

如果 flatten 设置为 true,则形状为:

  • data(batch_size, x1, x2, …, xn)
  • weight(num_hidden, x1 * x2 * … * xn)
  • bias(num_hidden,)
  • out(batch_size, num_hidden)

如果 flatten 设置为 false,则形状为:

  • data(x1, x2, …, xn, input_dim)
  • weight(num_hidden, input_dim)
  • bias(num_hidden,)
  • out(x1, x2, …, xn, num_hidden)

可学习的参数包括 weightbias

如果 no_bias 设置为 true,则忽略 bias 项。

注意

FullyConnected 的稀疏支持仅限于使用 row_sparse 权重和偏差进行前向评估,其中 weight.indicesbias.indices 的长度必须等于 num_hidden 。这对于使用重要性采样或噪声对比估计训练的row_sparse 权重的模型推断很有用。

要使用‘csr’稀疏数据计算线性变换,建议使用sparse.dot而不是sparse.FullyConnected。

例子

构造一个目标维度为 512 的全连接算子。

>>> data = Variable('data')  # or some constructed NN
>>> op = FullyConnected(data=data,
... num_hidden=512,
... name='FC1')
>>> op
<Symbol FC1>
>>> SymbolDoc.get_output_shape(op, data=(128, 100))
{'FC1_output': (128L, 512L)}

带有ReLU 激活的简单 3 层 MLP:

>>> net = Variable('data')
>>> for i, dim in enumerate([128, 64]):
... net = FullyConnected(data=net, num_hidden=dim, name='FC%d' % i)
... net = Activation(data=net, act_type='relu', name='ReLU%d' % i)
>>> # 10-class predictor (e.g. MNIST)
>>> net = FullyConnected(data=net, num_hidden=10, name='pred')
>>> net
<Symbol pred>
>>> dim_in, dim_out = (3, 4)
>>> x, w, b = test_utils.random_arrays((10, dim_in), (dim_out, dim_in), (dim_out,))
>>> op = FullyConnected(num_hidden=dim_out, name='FC')
>>> out = test_utils.simple_forward(op, FC_data=x, FC_weight=w, FC_bias=b)
>>> # numpy implementation of FullyConnected
>>> out_np = np.dot(x, w.T) + b
>>> test_utils.almost_equal(out, out_np)
True

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.op.FullyConnected。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。