本文简要介绍python语言中 torch.testing.assert_close 的用法。
用法:
torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)actual(任何) -实际输入。
expected(任何) -预期输入。
allow_subclasses(bool) -如果
True(默认)并且除了 Python 标量,直接相关类型的输入是允许的。否则需要类型相等。rtol(可选的[float]) -相对容差。如果指定了
atol,还必须指定。如果省略,则使用下表选择基于dtype的默认值。atol(可选的[float]) -绝对的宽容。如果指定了
rtol,还必须指定。如果省略,则使用下表选择基于dtype的默认值。check_device(bool) -如果
True(默认),则断言相应的张量在同一个device上。如果禁用此检查,则不同device上的张量会在比较之前移至 CPU。check_dtype(bool) -如果
True(默认),则断言相应的张量具有相同的dtype。如果禁用此检查,则在比较之前,具有不同dtype的张量将被提升为公共dtype(根据torch.promote_types())。check_stride(bool) -如果
True和相应的张量是跨步的,则断言它们具有相同的跨步。check_is_coalesced(bool) -如果
True(默认)和相应的张量是稀疏 COO,则检查actual和expected是否已合并或未合并。如果禁用此检查,则张量在比较之前会被coalesce()编辑。msg(可选的[联盟[str,可调用[[Tensor,Tensor,诊断],str]]]) - 如果相应张量的值不匹配,则使用可选的错误消息。可以作为 callable 传递,在这种情况下,它将使用不匹配的张量和关于不匹配的诊断名称空间来调用。详情见下文。
ValueError - 如果没有
torch.Tensor可以从输入构造。ValueError - 如果仅指定
rtol或atol。AssertionError - 如果相应的输入不是 Python 标量并且不直接相关。
AssertionError - 如果
allow_subclasses是False,但相应的输入不是 Python 标量并且具有不同的类型。AssertionError - 如果输入是
Sequence的,但它们的长度不匹配。AssertionError - 如果输入是
Mapping的,但它们的键集不匹配。AssertionError - 如果相应的张量没有相同的
shape。AssertionError - 如果相应的张量没有相同的
layout。AssertionError - 如果相应的张量被量化,但具有不同的
qscheme()。AssertionError - 如果
check_device是True,但对应的张量不在同一个device上。AssertionError - 如果
check_dtype是True,但对应的张量没有相同的dtype。AssertionError - 如果
check_stride是True,但对应的跨步张量没有相同的跨度。AssertionError - 如果
check_is_coalesced是True,但相应的稀疏 COO 张量既不是合并的也不是未合并的。AssertionError - 如果对应张量的值根据上面的定义不接近。
断言
actual和expected接近。如果
actual和expected是跨步的、非量化的、实值的和有限的,则它们被认为是接近的,如果并且它们具有相同的
device(如果check_device是True)、相同的dtype(如果check_dtype是True)和相同的步幅(如果check_stride是True)。非有限值(-inf和inf)仅当且仅当它们相等时才被视为接近。仅当equal_nan为True时,NaN才被视为彼此相等。如果
actual和expected是稀疏的(具有 COO 或 CSR 布局),则单独检查它们的跨步成员。索引,即 COO 的indices或 CSR 布局的crow_indices和col_indices,始终检查是否相等,而根据上述定义检查值是否接近。仅当两者都合并或未合并时(如果check_is_coalesced为True),稀疏 COO 张量才被视为接近。如果
actual和expected被量化,如果它们具有相同的qscheme()并且根据上面的定义dequantize()的结果是接近的,则它们被认为是接近的。actual和expected可以是Tensor或任何 tensor-or-scalar-likes,从中可以使用torch.Tensor构造torch.as_tensor()。除了 Python 标量,输入类型必须直接相关。此外,actual和expected可以是Sequence或Mapping在这种情况下,如果它们的结构匹配并且根据上述定义,它们的所有元素都被认为是接近的,则它们被认为是接近的。注意
Python 标量是类型关系要求的一个例外,因为它们的
type(),即int、float和complex等价于 tensor-like 的dtype。因此,可以检查不同类型的 Python 标量,但需要将check_dtype设置为False。下表显示了不同
dtype的默认rtol和atol。如果dtype不匹配,则使用两个公差的最大值。dtypertolatolfloat161e-31e-5bfloat161.6e-21e-5float321.3e-61e-5float641e-71e-7complex321e-31e-5complex641.3e-61e-5complex1281e-71e-7other
0.00.0如果其可调用对象具有以下属性,则将传递给
msg的诊断名称空间:number_of_elements(int):每个被比较的张量中的元素数。total_mismatches(int):不匹配的总数。max_abs_diff(Union[int, float]):输入的最大绝对差。max_abs_diff_idx(Union[int, Tuple[int, ...]]):最大绝对差的索引。atol(浮点数):允许的绝对公差。max_rel_diff(Union[int, float]):输入的最大相对差异。max_rel_diff_idx(Union[int, Tuple[int, ...]]):最大相对差异的索引。rtol(浮点数):允许的相对公差。
对于
max_abs_diff和max_rel_diff,类型取决于输入的dtype。注意
assert_close()具有高度可配置性,具有严格的默认设置。鼓励用户partial()它以适合他们的用例。例如,如果需要进行相等性检查,则可以定义一个assert_equal,默认情况下对每个dtype使用零容差:>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Absolute difference: 8.999999703829253e-10 Relative difference: 8.999999583666371例子
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... AssertionError: Except for Python scalars, type equality is required if allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'> instead. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... AssertionError: Except for Python scalars, input types need to be directly related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default mismatch message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # The error message can also created at runtime by passing a callable. >>> def custom_msg(actual, expected, diagnostics): ... ratio = diagnostics.total_mismatches / diagnostics.number_of_elements ... return ( ... f"Argh, we found {diagnostics.total_mismatches} mismatches! " ... f"That is {ratio:.1%}!" ... ) >>> torch.testing.assert_close(actual, expected, msg=custom_msg) Traceback (most recent call last): ... AssertionError: Argh, we found 2 mismatches! That is 66.7%!
参数:
抛出:
相关用法
- Python PyTorch async_execution用法及代码示例
- Python PyTorch as_strided用法及代码示例
- Python PyTorch asin用法及代码示例
- Python PyTorch asinh用法及代码示例
- Python PyTorch as_tensor用法及代码示例
- Python PyTorch argsort用法及代码示例
- Python PyTorch addmm用法及代码示例
- Python PyTorch addmv用法及代码示例
- Python PyTorch apply_effects_tensor用法及代码示例
- Python PyTorch angle用法及代码示例
- Python PyTorch all_reduce用法及代码示例
- Python PyTorch atanh用法及代码示例
- Python PyTorch annotate用法及代码示例
- Python PyTorch argmax用法及代码示例
- Python PyTorch atan用法及代码示例
- Python PyTorch acos用法及代码示例
- Python PyTorch all_gather用法及代码示例
- Python PyTorch avg_pool1d用法及代码示例
- Python PyTorch allreduce_hook用法及代码示例
- Python PyTorch argmin用法及代码示例
- Python PyTorch any用法及代码示例
- Python PyTorch all_to_all用法及代码示例
- Python PyTorch add用法及代码示例
- Python PyTorch addcdiv用法及代码示例
- Python PyTorch acosh用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.testing.assert_close。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
