當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.test_utils.check_consistency用法及代碼示例


用法:

mxnet.test_utils.check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, rtol=None, atol=None, raise_on_err=True, ground_truth=None, equal_nan=False, use_uniform=False, rand_type=<class 'numpy.float64'>)

參數

  • sym(Symbol or list of Symbols) - 運行一致性測試的符號。
  • ctx_list(list) - 運行上下文。有關更多詳細信息,請參見示例。
  • scale(float, optional) - 內部正態分布的標準差。用於初始化。
  • grad_req(str or list of str or dict of str to str) - 梯度要求。
  • arg_params(dict of input name -> input data) - 用於非輔助輸入的數據
  • aux_params(dict of input name -> input data) - 用於輔助輸入的數據
  • rtol(float or dictionary dtype->float, optional) - 相對誤差容限。
  • atol(float or dictionary dtype->float, optional) - 絕對誤差容限。
  • raise_on_err(bool, optional, defaults to True) - 如果錯誤引發異常(或僅輸出異常消息)
  • ground_truth(dict of output name -> data, optional) - 提供了比較理想的結果
  • equal_nan(bool, optional, defaults to False) - 在比較中是否應將 nan 視為平等
  • use_unifrom(bool) - 可選,當 flag 設置為 true 時,生成的隨機輸入數據遵循均勻分布,而不是正態分布
  • rand_type(np.dtype) - 將隨機生成的數據強製轉換為這種類型可選,當輸入數據通過arg_params 傳遞時,默認為 np.float64(numpy float 默認)

檢查符號為不同的運行上下文提供相同的輸出

例子

>>> # create the symbol
>>> sym = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')
>>> # initialize the running context
>>> ctx_list =[{'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}, {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float16}}, {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}]
>>> check_consistency(sym, ctx_list)
>>> sym = mx.sym.Concat(name='concat', num_args=2)
>>> ctx_list = [{'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}}, {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}, {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float16, 'concat_arg1': np.float16}}, {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}}, {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}]
>>> check_consistency(sym, ctx_list)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.test_utils.check_consistency。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。