当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.contrib.autograd.grad用法及代码示例


用法:

mxnet.contrib.autograd.grad(func, argnum=None)

参数

  • func(a python function) - 前向(损失)函数。
  • argnum(an int or a list of int) - 计算梯度的参数索引。

返回

grad_func- 一个计算参数梯度的函数。

返回类型

一个python函数

返回计算参数梯度的函数。

例子

>>> # autograd supports dynamic graph which is changed
>>> # every instance
>>> def func(x):
>>>     r = random.randint(0, 1)
>>>     if r % 2:
>>>         return x**2
>>>     else:
>>>         return x/3
>>> # use `grad(func)` to get the gradient function
>>> for x in range(10):
>>>     grad_func = grad(func)
>>>     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
>>>     grad_vals = grad_func(inputs)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.contrib.autograd.grad。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。