当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.Symbol.infer_shape用法及代码示例


用法:

infer_shape(*args, **kwargs)

参数

  • *args- 以位置方式显示参数的形状。未知形状可以标记为无。
  • **kwargs- 已知形状的关键字参数。

返回

  • arg_shapes(list of tuple or None) - 参数形状列表。顺序与list_arguments()的顺序相同。
  • out_shapes(list of tuple or None) - 输出形状列表。顺序与list_outputs()的顺序相同。
  • aux_shapes(list of tuple or None) - 辅助状态形状列表。顺序与list_auxiliary_states()的顺序相同。

给定一些参数的已知形状,推断所有参数和所有输出的形状。

此函数以位置方式或关键字参数方式将某些参数的已知形状作为输入。如果没有足够的信息来推断缺失的形状,它会返回一个 None 值的元组。

示例

>>> a = mx.sym.var('a')
>>> b = mx.sym.var('b')
>>> c = a + b
>>> arg_shapes, out_shapes, aux_shapes = c.infer_shape(a=(3,3))
>>> arg_shapes
[(3L, 3L), (3L, 3L)]
>>> out_shapes
[(3L, 3L)]
>>> aux_shapes
[]
>>> c.infer_shape(a=(0,3)) # 0s in shape means unknown dimensions. So, returns None.
(None, None, None)

已知形状的不一致将导致引发错误。请参见以下示例:

>>> data = mx.sym.Variable('data')
>>> out = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000)
>>> out = mx.sym.Activation(data=out, act_type='relu')
>>> out = mx.sym.FullyConnected(data=out, name='fc2', num_hidden=10)
>>> weight_shape= (1, 100)
>>> data_shape = (100, 100)
>>> out.infer_shape(data=data_shape, fc1_weight=weight_shape)
Error in operator fc1: Shape inconsistent, Provided=(1,100), inferred shape=(1000,100)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.Symbol.infer_shape。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。