當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python mxnet.contrib.autograd.grad用法及代碼示例

用法:

mxnet.contrib.autograd.grad(func, argnum=None)

參數

  • func(a python function) - 前向(損失)函數。
  • argnum(an int or a list of int) - 計算梯度的參數索引。

返回

grad_func- 一個計算參數梯度的函數。

返回類型

一個python函數

返回計算參數梯度的函數。

例子

>>> # autograd supports dynamic graph which is changed
>>> # every instance
>>> def func(x):
>>>     r = random.randint(0, 1)
>>>     if r % 2:
>>>         return x**2
>>>     else:
>>>         return x/3
>>> # use `grad(func)` to get the gradient function
>>> for x in range(10):
>>>     grad_func = grad(func)
>>>     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
>>>     grad_vals = grad_func(inputs)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.contrib.autograd.grad。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。