當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch hessian用法及代碼示例


本文簡要介紹python語言中 torch.autograd.functional.hessian 的用法。

用法:

torch.autograd.functional.hessian(func, inputs, create_graph=False, strict=False, vectorize=False)

參數

  • func(函數) -一個 Python 函數,它接受張量輸入並返回一個帶有單個元素的張量。

  • inputs(張量元組或者Tensor) -函數 func 的輸入。

  • create_graph(bool,可選的) -如果 True ,Hessian 將以可微分的方式計算。請注意,當 strictFalse 時,結果不能需要梯度或與輸入斷開連接。默認為 False

  • strict(bool,可選的) -如果 True ,當我們檢測到存在一個輸入使得所有輸出都獨立於它時,將引發錯誤。如果 False ,我們返回一個零張量作為所述輸入的粗麻布,這是預期的數學值。默認為 False

  • vectorize(bool,可選的) -此函數是實驗性的,請自行承擔使用風險。在計算 hessian 時,通常我們在 hessian 的每一行調用一次autograd.grad。如果此標誌是 True ,我們使用 vmap 原型函數作為後端來矢量化對 autograd.grad 的調用,因此我們隻調用它一次而不是每行一次。這應該會在許多用例中帶來性能改進,但是,由於此函數不完整,可能會出現性能懸崖。請使用torch._C._debug_only_display_vmap_fallback_warnings(True) 顯示任何性能警告,如果您的用例存在警告,請向我們提交問題。默認為 False

返回

如果隻有一個輸入,這將是一個包含輸入的 Hessian 的張量。如果它是一個元組,那麽 Hessian 將是一個元組的元組,其中 Hessian[i][j] 將包含第 i 輸入和第 j 輸入的 Hessian,其大小為第 i 輸入的大小之和加上 j th 輸入的大小。 Hessian[i][j] 將具有與相應的 i th 輸入相同的 dtype 和設備。

返回類型

Hessian(Tensor 或張量元組的元組)

計算給定標量函數的 Hessian 函數。

示例

>>> def pow_reducer(x):
...   return x.pow(3).sum()
>>> inputs = torch.rand(2, 2)
>>> hessian(pow_reducer, inputs)
tensor([[[[5.2265, 0.0000],
          [0.0000, 0.0000]],
         [[0.0000, 4.8221],
          [0.0000, 0.0000]]],
        [[[0.0000, 0.0000],
          [1.9456, 0.0000]],
         [[0.0000, 0.0000],
          [0.0000, 3.2550]]]])
>>> hessian(pow_reducer, inputs, create_graph=True)
tensor([[[[5.2265, 0.0000],
          [0.0000, 0.0000]],
         [[0.0000, 4.8221],
          [0.0000, 0.0000]]],
        [[[0.0000, 0.0000],
          [1.9456, 0.0000]],
         [[0.0000, 0.0000],
          [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
>>> def pow_adder_reducer(x, y):
...   return (2 * x.pow(2) + 3 * y.pow(2)).sum()
>>> inputs = (torch.rand(2), torch.rand(2))
>>> hessian(pow_adder_reducer, inputs)
((tensor([[4., 0.],
          [0., 4.]]),
  tensor([[0., 0.],
          [0., 0.]])),
 (tensor([[0., 0.],
          [0., 0.]]),
  tensor([[6., 0.],
          [0., 6.]])))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.functional.hessian。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。