用法:
mxnet.test_utils.check_symbolic_backward(sym, location, out_grads, expected, rtol=None, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, equal_nan=False, dtype=<class 'numpy.float32'>)
- sym:(
Symbol
) - 输出符号 - location:(
list of np.ndarray
or
dict of str to np.ndarray
) -评价点- 如果类型是np.ndarray 的列表
包含与
mx.sym.list_arguments
对应的所有 NumPy 数组。
- 如果类型是 str 到 np.ndarray 的字典
包含参数名称及其值之间的映射。
- out_grads:(
None
or
list of np.ndarray
or
dict of str to np.ndarray
) -NumPys 数组对应于sym.outputs,用于传入梯度。- 如果类型是np.ndarray 的列表
包含对应于
exe.outputs
的数组。
- 如果类型是 str 到 np.ndarray 的字典
包含 mxnet.sym list_output() 和 Executor.outputs 之间的映射
- expected:(
list of np.ndarray
or
dict of str to np.ndarray
) -预期梯度值- 如果类型是np.ndarray 的列表
包含对应exe的数组。grad_arrays
- 如果类型是 str 到 np.ndarray 的字典
包含
sym.list_arguments()
和exe.outputs 之间的映射
- rtol:(
None
or
float
) - 相对阈值。如果设置为,将使用默认阈值None
. - atol:(
None
or
float
) - 绝对阈值。如果设置为,将使用默认阈值None
. - aux_states:(
list of np.ndarray
or
dict of str to np.ndarray
) - - grad_req:(
str
or
list of str
or
dict of str to str
,
optional
) - 梯度要求。 ‘write’, ‘add’ or ‘null’。 - ctx:(mxnet.context.Context
,
optional
) - 运行上下文。 - grad_stypes:(
dict of str->str
) - 将参数名称映射到渐变 stype 的字典 - equal_nan:(
Boolean
) - 如果是真的,nan
是检查等效性的有效值(即nan
==nan
) - dtype:(
np.float16
or
np.float32
or
np.float64
) - mx.nd.array 的数据类型
- sym:(
参数:
将符号的反向结果与预期结果进行比较。如果向后结果与预期结果不同,则打印错误消息。
示例:
>>> lhs = mx.symbol.Variable('lhs') >>> rhs = mx.symbol.Variable('rhs') >>> sym_add = mx.symbol.elemwise_add(lhs, rhs) >>> mat1 = np.array([[1, 2], [3, 4]]) >>> mat2 = np.array([[5, 6], [7, 8]]) >>> grad1 = mx.nd.zeros(shape) >>> grad2 = mx.nd.zeros(shape) >>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2}, ... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'}) >>> exec_add.forward(is_train=True) >>> ograd = mx.nd.ones(shape) >>> grad_expected = ograd.copy().asnumpy() >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
相关用法
- Python mxnet.test_utils.check_symbolic_forward用法及代码示例
- Python mxnet.test_utils.check_consistency用法及代码示例
- Python mxnet.test_utils.chi_square_check用法及代码示例
- Python mxnet.test_utils.get_zip_data用法及代码示例
- Python mxnet.test_utils.rand_sparse_ndarray用法及代码示例
- Python mxnet.test_utils.var_check用法及代码示例
- Python mxnet.test_utils.get_bz2_data用法及代码示例
- Python mxnet.test_utils.mean_check用法及代码示例
- Python mxnet.symbol.op.broadcast_logical_xor用法及代码示例
- Python mxnet.ndarray.op.uniform用法及代码示例
- Python mxnet.symbol.op.log_softmax用法及代码示例
- Python mxnet.symbol.space_to_depth用法及代码示例
- Python mxnet.ndarray.op.sample_negative_binomial用法及代码示例
- Python mxnet.ndarray.NDArray.ndim用法及代码示例
- Python mxnet.module.BaseModule.get_outputs用法及代码示例
- Python mxnet.module.BaseModule.forward用法及代码示例
- Python mxnet.symbol.random_pdf_poisson用法及代码示例
- Python mxnet.ndarray.op.khatri_rao用法及代码示例
- Python mxnet.ndarray.op.unravel_index用法及代码示例
- Python mxnet.symbol.argmin用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.test_utils.check_symbolic_backward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。