當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch assert_close用法及代碼示例


本文簡要介紹python語言中 torch.testing.assert_close 的用法。

用法:

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)

參數

  • actual(任何) -實際輸入。

  • expected(任何) -預期輸入。

  • allow_subclasses(bool) -如果True(默認)並且除了 Python 標量,直接相關類型的輸入是允許的。否則需要類型相等。

  • rtol(可選的[float]) -相對容差。如果指定了atol,還必須指定。如果省略,則使用下表選擇基於dtype 的默認值。

  • atol(可選的[float]) -絕對的寬容。如果指定了rtol,還必須指定。如果省略,則使用下表選擇基於dtype 的默認值。

  • equal_nan(聯盟[bool,str]) -如果 True ,兩個 NaN 值將被視為相等。

  • check_device(bool) -如果 True(默認),則斷言相應的張量在同一個 device 上。如果禁用此檢查,則不同 device 上的張量會在比較之前移至 CPU。

  • check_dtype(bool) -如果 True(默認),則斷言相應的張量具有相同的 dtype 。如果禁用此檢查,則在比較之前,具有不同 dtype 的張量將被提升為公共 dtype(根據 torch.promote_types() )。

  • check_stride(bool) -如果 True 和相應的張量是跨步的,則斷言它們具有相同的跨步。

  • check_is_coalesced(bool) -如果 True(默認)和相應的張量是稀疏 COO,則檢查 actualexpected 是否已合並或未合並。如果禁用此檢查,則張量在比較之前會被 coalesce() 編輯。

  • msg(可選的[聯盟[str,可調用[[Tensor,Tensor,診斷],str]]]) - 如果相應張量的值不匹配,則使用可選的錯誤消息。可以作為 callable 傳遞,在這種情況下,它將使用不匹配的張量和關於不匹配的診斷名稱空間來調用。詳情見下文。

拋出

斷言actualexpected 接近。

如果 actualexpected 是跨步的、非量化的、實值的和有限的,則它們被認為是接近的,如果

並且它們具有相同的 device (如果 check_deviceTrue )、相同的 dtype (如果 check_dtypeTrue )和相同的步幅(如果 check_strideTrue )。非有限值(-infinf)僅當且僅當它們相等時才被視為接近。僅當 equal_nanTrue 時,NaN 才被視為彼此相等。

如果 actualexpected 是稀疏的(具有 COO 或 CSR 布局),則單獨檢查它們的跨步成員。索引,即 COO 的 indices 或 CSR 布局的 crow_indicescol_indices,始終檢查是否相等,而根據上述定義檢查值是否接近。僅當兩者都合並或未合並時(如果 check_is_coalescedTrue ),稀疏 COO 張量才被視為接近。

如果actualexpected被量化,如果它們具有相同的 qscheme() 並且根據上麵的定義 dequantize() 的結果是接近的,則它們被認為是接近的。

actualexpected 可以是 Tensor 或任何 tensor-or-scalar-likes,從中可以使用 torch.Tensor 構造 torch.as_tensor() 。除了 Python 標量,輸入類型必須直接相關。此外,actualexpected 可以是 Sequence Mapping 在這種情況下,如果它們的結構匹配並且根據上述定義,它們的所有元素都被認為是接近的,則它們被認為是接近的。

注意

Python 標量是類型關係要求的一個例外,因為它們的 type() ,即 int float complex 等價於 tensor-like 的 dtype。因此,可以檢查不同類型的 Python 標量,但需要將 check_dtype 設置為 False

下表顯示了不同 dtype 的默認 rtolatol。如果 dtype 不匹配,則使用兩個公差的最大值。

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

other

0.0

0.0

如果其可調用對象具有以下屬性,則將傳遞給msg 的診斷名稱空間:

  • number_of_elements (int):每個被比較的張量中的元素數。

  • total_mismatches (int):不匹配的總數。

  • max_abs_diff (Union[int, float]):輸入的最大絕對差。

  • max_abs_diff_idx (Union[int, Tuple[int, ...]]):最大絕對差的索引。

  • atol(浮點數):允許的絕對公差。

  • max_rel_diff (Union[int, float]):輸入的最大相對差異。

  • max_rel_diff_idx (Union[int, Tuple[int, ...]]):最大相對差異的索引。

  • rtol(浮點數):允許的相對公差。

對於max_abs_diffmax_rel_diff,類型取決於輸入的dtype

注意

assert_close() 具有高度可配置性,具有嚴格的默認設置。鼓勵用戶 partial() 它以適合他們的用例。例如,如果需要進行相等性檢查,則可以定義一個 assert_equal,默認情況下對每個 dtype 使用零容差:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Absolute difference: 8.999999703829253e-10
Relative difference: 8.999999583666371

例子

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, type equality is required if
allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and
<class 'torch.Tensor'> instead.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, input types need to be directly
related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # The error message can also created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostics):
...     ratio = diagnostics.total_mismatches / diagnostics.number_of_elements
...     return (
...         f"Argh, we found {diagnostics.total_mismatches} mismatches! "
...         f"That is {ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
Traceback (most recent call last):
...
AssertionError: Argh, we found 2 mismatches! That is 66.7%!

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.testing.assert_close。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。