當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.symbol.Symbol.infer_shape用法及代碼示例


用法:

infer_shape(*args, **kwargs)

參數

  • *args- 以位置方式顯示參數的形狀。未知形狀可以標記為無。
  • **kwargs- 已知形狀的關鍵字參數。

返回

  • arg_shapes(list of tuple or None) - 參數形狀列表。順序與list_arguments()的順序相同。
  • out_shapes(list of tuple or None) - 輸出形狀列表。順序與list_outputs()的順序相同。
  • aux_shapes(list of tuple or None) - 輔助狀態形狀列表。順序與list_auxiliary_states()的順序相同。

給定一些參數的已知形狀,推斷所有參數和所有輸出的形狀。

此函數以位置方式或關鍵字參數方式將某些參數的已知形狀作為輸入。如果沒有足夠的信息來推斷缺失的形狀,它會返回一個 None 值的元組。

示例

>>> a = mx.sym.var('a')
>>> b = mx.sym.var('b')
>>> c = a + b
>>> arg_shapes, out_shapes, aux_shapes = c.infer_shape(a=(3,3))
>>> arg_shapes
[(3L, 3L), (3L, 3L)]
>>> out_shapes
[(3L, 3L)]
>>> aux_shapes
[]
>>> c.infer_shape(a=(0,3)) # 0s in shape means unknown dimensions. So, returns None.
(None, None, None)

已知形狀的不一致將導致引發錯誤。請參見以下示例:

>>> data = mx.sym.Variable('data')
>>> out = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000)
>>> out = mx.sym.Activation(data=out, act_type='relu')
>>> out = mx.sym.FullyConnected(data=out, name='fc2', num_hidden=10)
>>> weight_shape= (1, 100)
>>> data_shape = (100, 100)
>>> out.infer_shape(data=data_shape, fc1_weight=weight_shape)
Error in operator fc1: Shape inconsistent, Provided=(1,100), inferred shape=(1000,100)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.symbol.Symbol.infer_shape。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。