當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.test_utils.check_symbolic_backward用法及代碼示例


用法:

mxnet.test_utils.check_symbolic_backward(sym, location, out_grads, expected, rtol=None, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, equal_nan=False, dtype=<class 'numpy.float32'>)

參數

  • sym(Symbol) - 輸出符號
  • location(list of np.ndarray or dict of str to np.ndarray) -評價點
    • 如果類型是np.ndarray 的列表

      包含與 mx.sym.list_arguments 對應的所有 NumPy 數組。

    • 如果類型是 str 到 np.ndarray 的字典

      包含參數名稱及其值之間的映射。

  • out_grads(None or list of np.ndarray or dict of str to np.ndarray) -NumPys 數組對應於sym.outputs,用於傳入梯度。
    • 如果類型是np.ndarray 的列表

      包含對應於 exe.outputs 的數組。

    • 如果類型是 str 到 np.ndarray 的字典

      包含 mxnet.sym list_output() 和 Executor.outputs 之間的映射

  • expected(list of np.ndarray or dict of str to np.ndarray) -預期梯度值
    • 如果類型是np.ndarray 的列表

      包含對應exe的數組。grad_arrays

    • 如果類型是 str 到 np.ndarray 的字典

      包含sym.list_arguments() 和exe.outputs 之間的映射

  • rtol(None or float) - 相對閾值。如果設置為,將使用默認閾值None.
  • atol(None or float) - 絕對閾值。如果設置為,將使用默認閾值None.
  • aux_states(list of np.ndarray or dict of str to np.ndarray) -
  • grad_req(str or list of str or dict of str to str, optional) - 梯度要求。 ‘write’, ‘add’ or ‘null’。
  • ctx(mxnet.context.Context, optional) - 運行上下文。
  • grad_stypes(dict of str->str) - 將參數名稱映射到漸變 stype 的字典
  • equal_nan(Boolean) - 如果是真的,nan是檢查等效性的有效值(即nan==nan)
  • dtype(np.float16 or np.float32 or np.float64) - mx.nd.array 的數據類型

將符號的反向結果與預期結果進行比較。如果向後結果與預期結果不同,則打印錯誤消息。

示例

>>> lhs = mx.symbol.Variable('lhs')
>>> rhs = mx.symbol.Variable('rhs')
>>> sym_add = mx.symbol.elemwise_add(lhs, rhs)
>>> mat1 = np.array([[1, 2], [3, 4]])
>>> mat2 = np.array([[5, 6], [7, 8]])
>>> grad1 = mx.nd.zeros(shape)
>>> grad2 = mx.nd.zeros(shape)
>>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2},
... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'})
>>> exec_add.forward(is_train=True)
>>> ograd = mx.nd.ones(shape)
>>> grad_expected = ograd.copy().asnumpy()
>>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.test_utils.check_symbolic_backward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。