当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.test_utils.check_symbolic_backward用法及代码示例


用法:

mxnet.test_utils.check_symbolic_backward(sym, location, out_grads, expected, rtol=None, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, equal_nan=False, dtype=<class 'numpy.float32'>)

参数

  • sym(Symbol) - 输出符号
  • location(list of np.ndarray or dict of str to np.ndarray) -评价点
    • 如果类型是np.ndarray 的列表

      包含与 mx.sym.list_arguments 对应的所有 NumPy 数组。

    • 如果类型是 str 到 np.ndarray 的字典

      包含参数名称及其值之间的映射。

  • out_grads(None or list of np.ndarray or dict of str to np.ndarray) -NumPys 数组对应于sym.outputs,用于传入梯度。
    • 如果类型是np.ndarray 的列表

      包含对应于 exe.outputs 的数组。

    • 如果类型是 str 到 np.ndarray 的字典

      包含 mxnet.sym list_output() 和 Executor.outputs 之间的映射

  • expected(list of np.ndarray or dict of str to np.ndarray) -预期梯度值
    • 如果类型是np.ndarray 的列表

      包含对应exe的数组。grad_arrays

    • 如果类型是 str 到 np.ndarray 的字典

      包含sym.list_arguments() 和exe.outputs 之间的映射

  • rtol(None or float) - 相对阈值。如果设置为,将使用默认阈值None.
  • atol(None or float) - 绝对阈值。如果设置为,将使用默认阈值None.
  • aux_states(list of np.ndarray or dict of str to np.ndarray) -
  • grad_req(str or list of str or dict of str to str, optional) - 梯度要求。 ‘write’, ‘add’ or ‘null’。
  • ctx(mxnet.context.Context, optional) - 运行上下文。
  • grad_stypes(dict of str->str) - 将参数名称映射到渐变 stype 的字典
  • equal_nan(Boolean) - 如果是真的,nan是检查等效性的有效值(即nan==nan)
  • dtype(np.float16 or np.float32 or np.float64) - mx.nd.array 的数据类型

将符号的反向结果与预期结果进行比较。如果向后结果与预期结果不同,则打印错误消息。

示例

>>> lhs = mx.symbol.Variable('lhs')
>>> rhs = mx.symbol.Variable('rhs')
>>> sym_add = mx.symbol.elemwise_add(lhs, rhs)
>>> mat1 = np.array([[1, 2], [3, 4]])
>>> mat2 = np.array([[5, 6], [7, 8]])
>>> grad1 = mx.nd.zeros(shape)
>>> grad2 = mx.nd.zeros(shape)
>>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2},
... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'})
>>> exec_add.forward(is_train=True)
>>> ograd = mx.nd.ones(shape)
>>> grad_expected = ograd.copy().asnumpy()
>>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.test_utils.check_symbolic_backward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。