当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.test_utils.check_consistency用法及代码示例


用法:

mxnet.test_utils.check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, rtol=None, atol=None, raise_on_err=True, ground_truth=None, equal_nan=False, use_uniform=False, rand_type=<class 'numpy.float64'>)

参数

  • sym(Symbol or list of Symbols) - 运行一致性测试的符号。
  • ctx_list(list) - 运行上下文。有关更多详细信息,请参见示例。
  • scale(float, optional) - 内部正态分布的标准差。用于初始化。
  • grad_req(str or list of str or dict of str to str) - 梯度要求。
  • arg_params(dict of input name -> input data) - 用于非辅助输入的数据
  • aux_params(dict of input name -> input data) - 用于辅助输入的数据
  • rtol(float or dictionary dtype->float, optional) - 相对误差容限。
  • atol(float or dictionary dtype->float, optional) - 绝对误差容限。
  • raise_on_err(bool, optional, defaults to True) - 如果错误引发异常(或仅输出异常消息)
  • ground_truth(dict of output name -> data, optional) - 提供了比较理想的结果
  • equal_nan(bool, optional, defaults to False) - 在比较中是否应将 nan 视为平等
  • use_unifrom(bool) - 可选,当 flag 设置为 true 时,生成的随机输入数据遵循均匀分布,而不是正态分布
  • rand_type(np.dtype) - 将随机生成的数据强制转换为这种类型可选,当输入数据通过arg_params 传递时,默认为 np.float64(numpy float 默认)

检查符号为不同的运行上下文提供相同的输出

例子

>>> # create the symbol
>>> sym = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')
>>> # initialize the running context
>>> ctx_list =[{'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}, {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float16}}, {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}]
>>> check_consistency(sym, ctx_list)
>>> sym = mx.sym.Concat(name='concat', num_args=2)
>>> ctx_list = [{'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}}, {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}, {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float16, 'concat_arg1': np.float16}}, {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}}, {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}]
>>> check_consistency(sym, ctx_list)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.test_utils.check_consistency。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。