當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.module.BaseModule.get_input_grads用法及代碼示例


用法:

get_input_grads(merge_multi_context=True)

參數

merge_multi_context(bool) - 默認為True.在使用data-parallelism 的情況下,將從多個設備收集梯度。一種Truevalue 表示我們應該合並收集的結果,以便它們看起來像來自單個執行器。

返回

輸入梯度。

返回類型

NDArray 列表或 NDArray 列表

獲取輸入的梯度,在先前的反向計算中計算。

如果 merge_multi_contextTrue ,則類似於 [grad1, grad2] 。否則,它就像 [[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]] 。所有輸出元素的類型均為 NDArray 。當 merge_multi_contextFalse 時,那些 NDArray 實例可能存在於不同的設備上。

例子

>>> # An example of getting input gradients.
>>> print mod.get_input_grads()[0].asnumpy()
[[[  1.10182791e-05   5.12257748e-06   4.01927764e-06   8.32566820e-06
    -1.59775993e-06   7.24269375e-06   7.28067835e-06  -1.65902311e-05
    5.46342608e-06   8.44196393e-07]
    ...]]

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule.get_input_grads。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。