當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch set_default_dtype用法及代碼示例


本文簡要介紹python語言中 torch.set_default_dtype 的用法。

用法:

torch.set_default_dtype(d)

參數

d(torch.dtype) -要設為默認值的浮點 dtype。 torch.float32 或 torch.float64。

將默認浮點數據類型設置為 d 。支持 torch.float32 和 torch.float64 作為輸入。其他 dtype 可能會毫無怨言地被接受,但不受支持且不太可能按預期工作。

當PyTorch初始化時,其默認浮點數據類型為torch.float32,set_default_dtype(torch.float64)的目的是促進NumPy-like類型推斷。默認浮點數據類型用於:

  1. 隱式確定默認的複雜數據類型。當默認浮點類型為 float32 時,默認複數 dtype 為 complex64,當默認浮點類型為 float64 時,默認複數類型為 complex128。

  2. 推斷使用 Python 浮點數或複雜 Python 數字構造的張量的 dtype。請參閱下麵的示例。

  3. 確定 bool 和整數張量以及 Python 浮點數和複雜 Python 數字之間的類型提升的結果。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
torch.complex128

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.set_default_dtype。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。