当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch hessian用法及代码示例


本文简要介绍python语言中 torch.autograd.functional.hessian 的用法。

用法:

torch.autograd.functional.hessian(func, inputs, create_graph=False, strict=False, vectorize=False)

参数

  • func(函数) -一个 Python 函数,它接受张量输入并返回一个带有单个元素的张量。

  • inputs(张量元组或者Tensor) -函数 func 的输入。

  • create_graph(bool,可选的) -如果 True ,Hessian 将以可微分的方式计算。请注意,当 strictFalse 时,结果不能需要梯度或与输入断开连接。默认为 False

  • strict(bool,可选的) -如果 True ,当我们检测到存在一个输入使得所有输出都独立于它时,将引发错误。如果 False ,我们返回一个零张量作为所述输入的粗麻布,这是预期的数学值。默认为 False

  • vectorize(bool,可选的) -此函数是实验性的,请自行承担使用风险。在计算 hessian 时,通常我们在 hessian 的每一行调用一次autograd.grad。如果此标志是 True ,我们使用 vmap 原型函数作为后端来矢量化对 autograd.grad 的调用,因此我们只调用它一次而不是每行一次。这应该会在许多用例中带来性能改进,但是,由于此函数不完整,可能会出现性能悬崖。请使用torch._C._debug_only_display_vmap_fallback_warnings(True) 显示任何性能警告,如果您的用例存在警告,请向我们提交问题。默认为 False

返回

如果只有一个输入,这将是一个包含输入的 Hessian 的张量。如果它是一个元组,那么 Hessian 将是一个元组的元组,其中 Hessian[i][j] 将包含第 i 输入和第 j 输入的 Hessian,其大小为第 i 输入的大小之和加上 j th 输入的大小。 Hessian[i][j] 将具有与相应的 i th 输入相同的 dtype 和设备。

返回类型

Hessian(Tensor 或张量元组的元组)

计算给定标量函数的 Hessian 函数。

示例

>>> def pow_reducer(x):
...   return x.pow(3).sum()
>>> inputs = torch.rand(2, 2)
>>> hessian(pow_reducer, inputs)
tensor([[[[5.2265, 0.0000],
          [0.0000, 0.0000]],
         [[0.0000, 4.8221],
          [0.0000, 0.0000]]],
        [[[0.0000, 0.0000],
          [1.9456, 0.0000]],
         [[0.0000, 0.0000],
          [0.0000, 3.2550]]]])
>>> hessian(pow_reducer, inputs, create_graph=True)
tensor([[[[5.2265, 0.0000],
          [0.0000, 0.0000]],
         [[0.0000, 4.8221],
          [0.0000, 0.0000]]],
        [[[0.0000, 0.0000],
          [1.9456, 0.0000]],
         [[0.0000, 0.0000],
          [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
>>> def pow_adder_reducer(x, y):
...   return (2 * x.pow(2) + 3 * y.pow(2)).sum()
>>> inputs = (torch.rand(2), torch.rand(2))
>>> hessian(pow_adder_reducer, inputs)
((tensor([[4., 0.],
          [0., 4.]]),
  tensor([[0., 0.],
          [0., 0.]])),
 (tensor([[0., 0.],
          [0., 0.]]),
  tensor([[6., 0.],
          [0., 6.]])))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.functional.hessian。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。