本文簡要介紹python語言中 torch.testing.assert_close
的用法。
用法:
torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)
actual(任何) -實際輸入。
expected(任何) -預期輸入。
allow_subclasses(bool) -如果
True
(默認)並且除了 Python 標量,直接相關類型的輸入是允許的。否則需要類型相等。rtol(可選的[float]) -相對容差。如果指定了
atol
,還必須指定。如果省略,則使用下表選擇基於dtype
的默認值。atol(可選的[float]) -絕對的寬容。如果指定了
rtol
,還必須指定。如果省略,則使用下表選擇基於dtype
的默認值。check_device(bool) -如果
True
(默認),則斷言相應的張量在同一個device
上。如果禁用此檢查,則不同device
上的張量會在比較之前移至 CPU。check_dtype(bool) -如果
True
(默認),則斷言相應的張量具有相同的dtype
。如果禁用此檢查,則在比較之前,具有不同dtype
的張量將被提升為公共dtype
(根據torch.promote_types()
)。check_stride(bool) -如果
True
和相應的張量是跨步的,則斷言它們具有相同的跨步。check_is_coalesced(bool) -如果
True
(默認)和相應的張量是稀疏 COO,則檢查actual
和expected
是否已合並或未合並。如果禁用此檢查,則張量在比較之前會被coalesce()
編輯。msg(可選的[聯盟[str,可調用[[Tensor,Tensor,診斷],str]]]) - 如果相應張量的值不匹配,則使用可選的錯誤消息。可以作為 callable 傳遞,在這種情況下,它將使用不匹配的張量和關於不匹配的診斷名稱空間來調用。詳情見下文。
ValueError - 如果沒有
torch.Tensor
可以從輸入構造。ValueError - 如果僅指定
rtol
或atol
。AssertionError - 如果相應的輸入不是 Python 標量並且不直接相關。
AssertionError - 如果
allow_subclasses
是False
,但相應的輸入不是 Python 標量並且具有不同的類型。AssertionError - 如果輸入是
Sequence
的,但它們的長度不匹配。AssertionError - 如果輸入是
Mapping
的,但它們的鍵集不匹配。AssertionError - 如果相應的張量沒有相同的
shape
。AssertionError - 如果相應的張量沒有相同的
layout
。AssertionError - 如果相應的張量被量化,但具有不同的
qscheme()
。AssertionError - 如果
check_device
是True
,但對應的張量不在同一個device
上。AssertionError - 如果
check_dtype
是True
,但對應的張量沒有相同的dtype
。AssertionError - 如果
check_stride
是True
,但對應的跨步張量沒有相同的跨度。AssertionError - 如果
check_is_coalesced
是True
,但相應的稀疏 COO 張量既不是合並的也不是未合並的。AssertionError - 如果對應張量的值根據上麵的定義不接近。
斷言
actual
和expected
接近。如果
actual
和expected
是跨步的、非量化的、實值的和有限的,則它們被認為是接近的,如果並且它們具有相同的
device
(如果check_device
是True
)、相同的dtype
(如果check_dtype
是True
)和相同的步幅(如果check_stride
是True
)。非有限值(-inf
和inf
)僅當且僅當它們相等時才被視為接近。僅當equal_nan
為True
時,NaN
才被視為彼此相等。如果
actual
和expected
是稀疏的(具有 COO 或 CSR 布局),則單獨檢查它們的跨步成員。索引,即 COO 的indices
或 CSR 布局的crow_indices
和col_indices
,始終檢查是否相等,而根據上述定義檢查值是否接近。僅當兩者都合並或未合並時(如果check_is_coalesced
為True
),稀疏 COO 張量才被視為接近。如果
actual
和expected
被量化,如果它們具有相同的qscheme()
並且根據上麵的定義dequantize()
的結果是接近的,則它們被認為是接近的。actual
和expected
可以是Tensor
或任何 tensor-or-scalar-likes,從中可以使用torch.Tensor
構造torch.as_tensor()
。除了 Python 標量,輸入類型必須直接相關。此外,actual
和expected
可以是Sequence
或Mapping
在這種情況下,如果它們的結構匹配並且根據上述定義,它們的所有元素都被認為是接近的,則它們被認為是接近的。注意
Python 標量是類型關係要求的一個例外,因為它們的
type()
,即int
、float
和complex
等價於 tensor-like 的dtype
。因此,可以檢查不同類型的 Python 標量,但需要將check_dtype
設置為False
。下表顯示了不同
dtype
的默認rtol
和atol
。如果dtype
不匹配,則使用兩個公差的最大值。dtype
rtol
atol
float16
1e-3
1e-5
bfloat16
1.6e-2
1e-5
float32
1.3e-6
1e-5
float64
1e-7
1e-7
complex32
1e-3
1e-5
complex64
1.3e-6
1e-5
complex128
1e-7
1e-7
other
0.0
0.0
如果其可調用對象具有以下屬性,則將傳遞給
msg
的診斷名稱空間:number_of_elements
(int):每個被比較的張量中的元素數。total_mismatches
(int):不匹配的總數。max_abs_diff
(Union[int, float]):輸入的最大絕對差。max_abs_diff_idx
(Union[int, Tuple[int, ...]]):最大絕對差的索引。atol
(浮點數):允許的絕對公差。max_rel_diff
(Union[int, float]):輸入的最大相對差異。max_rel_diff_idx
(Union[int, Tuple[int, ...]]):最大相對差異的索引。rtol
(浮點數):允許的相對公差。
對於
max_abs_diff
和max_rel_diff
,類型取決於輸入的dtype
。注意
assert_close()
具有高度可配置性,具有嚴格的默認設置。鼓勵用戶partial()
它以適合他們的用例。例如,如果需要進行相等性檢查,則可以定義一個assert_equal
,默認情況下對每個dtype
使用零容差:>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Absolute difference: 8.999999703829253e-10 Relative difference: 8.999999583666371
例子
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... AssertionError: Except for Python scalars, type equality is required if allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'> instead. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... AssertionError: Except for Python scalars, input types need to be directly related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default mismatch message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # The error message can also created at runtime by passing a callable. >>> def custom_msg(actual, expected, diagnostics): ... ratio = diagnostics.total_mismatches / diagnostics.number_of_elements ... return ( ... f"Argh, we found {diagnostics.total_mismatches} mismatches! " ... f"That is {ratio:.1%}!" ... ) >>> torch.testing.assert_close(actual, expected, msg=custom_msg) Traceback (most recent call last): ... AssertionError: Argh, we found 2 mismatches! That is 66.7%!
參數:
拋出:
相關用法
- Python PyTorch async_execution用法及代碼示例
- Python PyTorch as_strided用法及代碼示例
- Python PyTorch asin用法及代碼示例
- Python PyTorch asinh用法及代碼示例
- Python PyTorch as_tensor用法及代碼示例
- Python PyTorch argsort用法及代碼示例
- Python PyTorch addmm用法及代碼示例
- Python PyTorch addmv用法及代碼示例
- Python PyTorch apply_effects_tensor用法及代碼示例
- Python PyTorch angle用法及代碼示例
- Python PyTorch all_reduce用法及代碼示例
- Python PyTorch atanh用法及代碼示例
- Python PyTorch annotate用法及代碼示例
- Python PyTorch argmax用法及代碼示例
- Python PyTorch atan用法及代碼示例
- Python PyTorch acos用法及代碼示例
- Python PyTorch all_gather用法及代碼示例
- Python PyTorch avg_pool1d用法及代碼示例
- Python PyTorch allreduce_hook用法及代碼示例
- Python PyTorch argmin用法及代碼示例
- Python PyTorch any用法及代碼示例
- Python PyTorch all_to_all用法及代碼示例
- Python PyTorch add用法及代碼示例
- Python PyTorch addcdiv用法及代碼示例
- Python PyTorch acosh用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.testing.assert_close。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。