用法:
mxnet.test_utils.check_symbolic_backward(sym, location, out_grads, expected, rtol=None, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, equal_nan=False, dtype=<class 'numpy.float32'>)
- sym:(
Symbol
) - 輸出符號 - location:(
list of np.ndarray
or
dict of str to np.ndarray
) -評價點- 如果類型是np.ndarray 的列表
包含與
mx.sym.list_arguments
對應的所有 NumPy 數組。
- 如果類型是 str 到 np.ndarray 的字典
包含參數名稱及其值之間的映射。
- out_grads:(
None
or
list of np.ndarray
or
dict of str to np.ndarray
) -NumPys 數組對應於sym.outputs,用於傳入梯度。- 如果類型是np.ndarray 的列表
包含對應於
exe.outputs
的數組。
- 如果類型是 str 到 np.ndarray 的字典
包含 mxnet.sym list_output() 和 Executor.outputs 之間的映射
- expected:(
list of np.ndarray
or
dict of str to np.ndarray
) -預期梯度值- 如果類型是np.ndarray 的列表
包含對應exe的數組。grad_arrays
- 如果類型是 str 到 np.ndarray 的字典
包含
sym.list_arguments()
和exe.outputs 之間的映射
- rtol:(
None
or
float
) - 相對閾值。如果設置為,將使用默認閾值None
. - atol:(
None
or
float
) - 絕對閾值。如果設置為,將使用默認閾值None
. - aux_states:(
list of np.ndarray
or
dict of str to np.ndarray
) - - grad_req:(
str
or
list of str
or
dict of str to str
,
optional
) - 梯度要求。 ‘write’, ‘add’ or ‘null’。 - ctx:(mxnet.context.Context
,
optional
) - 運行上下文。 - grad_stypes:(
dict of str->str
) - 將參數名稱映射到漸變 stype 的字典 - equal_nan:(
Boolean
) - 如果是真的,nan
是檢查等效性的有效值(即nan
==nan
) - dtype:(
np.float16
or
np.float32
or
np.float64
) - mx.nd.array 的數據類型
- sym:(
參數:
將符號的反向結果與預期結果進行比較。如果向後結果與預期結果不同,則打印錯誤消息。
示例:
>>> lhs = mx.symbol.Variable('lhs') >>> rhs = mx.symbol.Variable('rhs') >>> sym_add = mx.symbol.elemwise_add(lhs, rhs) >>> mat1 = np.array([[1, 2], [3, 4]]) >>> mat2 = np.array([[5, 6], [7, 8]]) >>> grad1 = mx.nd.zeros(shape) >>> grad2 = mx.nd.zeros(shape) >>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2}, ... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'}) >>> exec_add.forward(is_train=True) >>> ograd = mx.nd.ones(shape) >>> grad_expected = ograd.copy().asnumpy() >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
相關用法
- Python mxnet.test_utils.check_symbolic_forward用法及代碼示例
- Python mxnet.test_utils.check_consistency用法及代碼示例
- Python mxnet.test_utils.chi_square_check用法及代碼示例
- Python mxnet.test_utils.get_zip_data用法及代碼示例
- Python mxnet.test_utils.rand_sparse_ndarray用法及代碼示例
- Python mxnet.test_utils.var_check用法及代碼示例
- Python mxnet.test_utils.get_bz2_data用法及代碼示例
- Python mxnet.test_utils.mean_check用法及代碼示例
- Python mxnet.symbol.op.broadcast_logical_xor用法及代碼示例
- Python mxnet.ndarray.op.uniform用法及代碼示例
- Python mxnet.symbol.op.log_softmax用法及代碼示例
- Python mxnet.symbol.space_to_depth用法及代碼示例
- Python mxnet.ndarray.op.sample_negative_binomial用法及代碼示例
- Python mxnet.ndarray.NDArray.ndim用法及代碼示例
- Python mxnet.module.BaseModule.get_outputs用法及代碼示例
- Python mxnet.module.BaseModule.forward用法及代碼示例
- Python mxnet.symbol.random_pdf_poisson用法及代碼示例
- Python mxnet.ndarray.op.khatri_rao用法及代碼示例
- Python mxnet.ndarray.op.unravel_index用法及代碼示例
- Python mxnet.symbol.argmin用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.test_utils.check_symbolic_backward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。