torch.testing
Warning
This module is in a PROTOTYPE state. New functions are still being added, and the available functions may change in future PyTorch releases. We are actively looking for feedback on UI/UX improvements or missing functionality.
torch.testing.assert_close(actual, expected, *, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=True, check_is_coalesced=True, msg=None)

Asserts that actual and expected are close.

If actual and expected are strided, real-valued, and finite, they are considered close if

$$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$$

and they have the same device (if check_device is True), the same dtype (if check_dtype is True), and the same stride (if check_stride is True). Non-finite values (-inf and inf) are only considered close if they are equal. NaN’s are only considered equal to each other if equal_nan is True.

If actual and expected are complex-valued, they are considered close if both their real and imaginary components are considered close according to the definition above.

If actual and expected are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely indices for COO or crow_indices and col_indices for CSR layout, are always checked for equality, whereas the values are checked for closeness according to the definition above. Sparse COO tensors are only considered close if both are either coalesced or uncoalesced (if check_is_coalesced is True).

actual and expected can be Tensor’s or any array-or-scalar-like of the same type from which torch.Tensor’s can be constructed with torch.as_tensor(). In addition, actual and expected can be Sequence’s or Mapping’s, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.

Parameters
actual (Any) – Actual input.
expected (Any) – Expected input.
rtol (Optional[float]) – Relative tolerance. If specified, atol must also be specified. If omitted, default values based on the dtype are selected with the table below.
atol (Optional[float]) – Absolute tolerance. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected with the table below.
equal_nan (Union[bool, str]) – If True, two NaN values will be considered equal. If "relaxed", complex values are considered NaN if either the real or imaginary component is NaN.
check_device (bool) – If True (default), asserts that corresponding tensors are on the same device. If this check is disabled, tensors on different device’s are moved to the CPU before being compared.
check_dtype (bool) – If True (default), asserts that corresponding tensors have the same dtype. If this check is disabled, tensors with different dtype’s are promoted to a common dtype (according to torch.promote_types()) before being compared.
check_stride (bool) – If True (default) and corresponding tensors are strided, asserts that they have the same stride.
check_is_coalesced (bool) – If True (default) and corresponding tensors are sparse COO, checks that both actual and expected are either coalesced or uncoalesced. If this check is disabled, tensors are coalesce()’ed before being compared.
msg (Optional[Union[str, Callable[[Tensor, Tensor, DiagnosticInfo], str]]]) – Optional error message to use if the values of corresponding tensors mismatch. Can be passed as a callable, in which case it will be called with the mismatching tensors and a namespace of diagnostic info about the mismatches. See below for details.
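For scalar inputs, the closeness rule from the description together with the rtol, atol, and equal_nan parameters can be sketched in plain Python. This is a minimal illustration under stated assumptions, not torch.testing's implementation: the helper name is_close is hypothetical, it handles only Python int, float, and complex values, and it ignores the device, dtype, and stride checks.

```python
import cmath
import math
from typing import Union

Number = Union[int, float, complex]


def is_close(actual: Number, expected: Number, *, rtol: float, atol: float,
             equal_nan: Union[bool, str] = False) -> bool:
    """Hypothetical scalar sketch of the closeness rule described for assert_close."""
    if isinstance(actual, complex) or isinstance(expected, complex):
        a, e = complex(actual), complex(expected)
        # equal_nan="relaxed": a complex value counts as NaN if either component is NaN.
        if equal_nan == "relaxed" and cmath.isnan(a) and cmath.isnan(e):
            return True
        # Otherwise the real and imaginary components must each be close.
        component_nan = bool(equal_nan)
        return (is_close(a.real, e.real, rtol=rtol, atol=atol, equal_nan=component_nan)
                and is_close(a.imag, e.imag, rtol=rtol, atol=atol, equal_nan=component_nan))
    if math.isnan(actual) or math.isnan(expected):
        # NaN only matches NaN, and only if equal_nan is enabled.
        return bool(equal_nan) and math.isnan(actual) and math.isnan(expected)
    if math.isinf(actual) or math.isinf(expected):
        # Non-finite values are close only if they are exactly equal.
        return actual == expected
    # Finite values: |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Usage: NaNs in swapped components only match under "relaxed".
a, e = complex(0, float("nan")), complex(float("nan"), 0)
print(is_close(a, e, rtol=0.0, atol=0.0, equal_nan=True))       # False
print(is_close(a, e, rtol=0.0, atol=0.0, equal_nan="relaxed"))  # True
```

Because the tolerance term scales with |expected| only, is_close(x, y, ...) and is_close(y, x, ...) can disagree for values right at the tolerance boundary.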
Raises
UsageError – If a torch.Tensor can’t be constructed from an array-or-scalar-like.
UsageError – If any tensor is quantized. This is a temporary restriction and will be relaxed in the future.
UsageError – If only rtol or atol is specified.
AssertionError – If corresponding array-likes have different types.
AssertionError – If the inputs are Sequence’s, but their lengths do not match.
AssertionError – If the inputs are Mapping’s, but their sets of keys do not match.
AssertionError – If corresponding tensors do not have the same shape.
AssertionError – If corresponding tensors do not have the same layout.
AssertionError – If check_device is True, but corresponding tensors are not on the same device.
AssertionError – If check_dtype is True, but corresponding tensors do not have the same dtype.
AssertionError – If check_stride is True, but corresponding strided tensors do not have the same stride.
AssertionError – If check_is_coalesced is True, but corresponding sparse COO tensors are not both either coalesced or uncoalesced.
AssertionError – If the values of corresponding tensors are not close.
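The Sequence and Mapping entries above describe a structural comparison that recurses into containers before any values are compared. A minimal pure-Python sketch of that walk follows; assert_structure_close is a hypothetical helper, it ignores leaf types, and math.isclose (which uses a symmetric relative tolerance) stands in for the tensor comparison, so this is not how torch.testing implements it.

```python
import functools
import math
from collections.abc import Mapping, Sequence
from typing import Any, Callable


def _is_sequence(x: Any) -> bool:
    # Strings are Sequences in Python, but are treated as leaves here.
    return isinstance(x, Sequence) and not isinstance(x, str)


def assert_structure_close(actual: Any, expected: Any,
                           leaf_close: Callable[[Any, Any], bool],
                           path: str = "input") -> None:
    """Hypothetical sketch: match container structure, then compare leaves."""
    if isinstance(actual, Mapping) or isinstance(expected, Mapping):
        if not (isinstance(actual, Mapping) and isinstance(expected, Mapping)):
            raise AssertionError(f"{path}: only one input is a Mapping")
        # Concrete type and key order are ignored; the sets of keys must match.
        if set(actual) != set(expected):
            raise AssertionError(f"{path}: sets of keys do not match")
        for key in expected:
            assert_structure_close(actual[key], expected[key], leaf_close, f"{path}[{key!r}]")
        return
    if _is_sequence(actual) or _is_sequence(expected):
        if not (_is_sequence(actual) and _is_sequence(expected)):
            raise AssertionError(f"{path}: only one input is a Sequence")
        # list vs. tuple is fine; only the lengths must match.
        if len(actual) != len(expected):
            raise AssertionError(f"{path}: lengths do not match")
        for i, (a, e) in enumerate(zip(actual, expected)):
            assert_structure_close(a, e, leaf_close, f"{path}[{i}]")
        return
    if not leaf_close(actual, expected):
        raise AssertionError(f"{path}: values are not close")


# Usage: a dict of lists matches an equivalent dict of tuples with a different key order.
leaf = functools.partial(math.isclose, rel_tol=1e-6)
assert_structure_close({"a": [1.0, 2.0], "b": 3.0}, {"b": 3.0, "a": (1.0, 2.0)}, leaf)
```

Carrying the path through the recursion lets a failure point at the exact element, for example input['a'][1].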
The following table displays the default rtol and atol for different dtype’s. Note that the dtype refers to the promoted type in case actual and expected do not have the same dtype.

dtype        rtol      atol
float16      1e-3      1e-5
bfloat16     1.6e-2    1e-5
float32      1.3e-6    1e-5
float64      1e-7      1e-7
complex32    1e-3      1e-5
complex64    1.3e-6    1e-5
complex128   1e-7      1e-7
other        0.0       0.0
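The defaults in the table can be read as a lookup keyed by the (promoted) dtype, with exact comparison as the fallback for every other dtype. The sketch below encodes that lookup in plain Python; DEFAULT_TOLERANCES, default_tolerances, and within_defaults are hypothetical names, and the values are ordinary Python floats, so the sketch does not emulate reduced-precision rounding.

```python
from typing import Dict, Tuple

# (rtol, atol) pairs copied from the table above, keyed by dtype name.
DEFAULT_TOLERANCES: Dict[str, Tuple[float, float]] = {
    "float16": (1e-3, 1e-5),
    "bfloat16": (1.6e-2, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
    "complex32": (1e-3, 1e-5),
    "complex64": (1.3e-6, 1e-5),
    "complex128": (1e-7, 1e-7),
}


def default_tolerances(dtype_name: str) -> Tuple[float, float]:
    # Every dtype not listed (e.g. integer or bool types) requires exact equality.
    return DEFAULT_TOLERANCES.get(dtype_name, (0.0, 0.0))


def within_defaults(actual: float, expected: float, dtype_name: str) -> bool:
    rtol, atol = default_tolerances(dtype_name)
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A deviation of about 1e-6 is within the float32 defaults but not the float64 ones.
print(within_defaults(1.0 + 1e-6, 1.0, "float32"))  # True
print(within_defaults(1.0 + 1e-6, 1.0, "float64"))  # False
```

Around an expected value of 1.0, the float32 bound is 1e-5 + 1.3e-6 = 1.13e-5, while the float64 bound is only 1e-7 + 1e-7 = 2e-7.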
The namespace of diagnostic information that will be passed to msg if it’s a callable has the following attributes:

number_of_elements (int): Number of elements in each tensor being compared.
total_mismatches (int): Total number of mismatches.
mismatch_ratio (float): Total mismatches divided by number of elements.
max_abs_diff (Union[int, float]): Greatest absolute difference of the inputs.
max_abs_diff_idx (Union[int, Tuple[int, …]]): Index of greatest absolute difference.
max_rel_diff (Union[int, float]): Greatest relative difference of the inputs.
max_rel_diff_idx (Union[int, Tuple[int, …]]): Index of greatest relative difference.
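To make the attributes above concrete, the sketch below computes them for two flat, non-empty lists of real numbers and then formats the same message as the custom_msg example in the Examples section. DiagnosticSketch and diagnose are hypothetical names; indices are plain ints because the sketch handles only 1-D inputs, and defining the relative difference as |actual - expected| / |expected| (with inf for a nonzero difference against a zero expected value) is an assumption.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class DiagnosticSketch:
    number_of_elements: int
    total_mismatches: int
    mismatch_ratio: float
    max_abs_diff: float
    max_abs_diff_idx: int
    max_rel_diff: float
    max_rel_diff_idx: int


def diagnose(actual: Sequence[float], expected: Sequence[float], *,
             rtol: float, atol: float) -> DiagnosticSketch:
    """Hypothetical sketch of the diagnostic fields for flat, non-empty inputs."""
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    # Assumed relative difference: |actual - expected| / |expected|.
    rel_diffs = [d / abs(e) if e != 0 else (0.0 if d == 0 else math.inf)
                 for d, e in zip(abs_diffs, expected)]
    # An element mismatches when it violates |a - e| <= atol + rtol * |e|.
    mismatches = sum(d > atol + rtol * abs(e) for d, e in zip(abs_diffs, expected))
    n = len(abs_diffs)
    # max() returns the first index on ties.
    abs_idx = max(range(n), key=abs_diffs.__getitem__)
    rel_idx = max(range(n), key=rel_diffs.__getitem__)
    return DiagnosticSketch(n, mismatches, mismatches / n,
                            abs_diffs[abs_idx], abs_idx, rel_diffs[rel_idx], rel_idx)


info = diagnose([1.0, 4.0, 5.0], [1.0, 2.0, 3.0], rtol=1.3e-6, atol=1e-5)
print(f"Argh, we found {info.total_mismatches} mismatches! "
      f"That is {info.mismatch_ratio:.1%}!")
# Argh, we found 2 mismatches! That is 66.7%!
```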
For max_abs_diff and max_rel_diff, the type depends on the dtype of the inputs.

Note

assert_close() is highly configurable with strict default settings. Users are encouraged to functools.partial() it to fit their use case. For example, if an equality check is needed, one might define an assert_equal that uses zero tolerances for every dtype by default:

>>> import functools
>>> import torch
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
AssertionError: Tensors are not close!
Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 8.999999703829253e-10 at 0 (up to 0 allowed)
Greatest relative difference: 8.999999583666371 at 0 (up to 0 allowed)
Examples
>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> # Different input types are never considered close.
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.numpy()
>>> torch.testing.assert_close(actual, expected)
AssertionError: Except for scalars, type equality is required, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Scalars of different types are an exception and can be compared with
>>> # check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("NaN"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
AssertionError: Tensors are not close!
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> # If equal_nan=True, the real and imaginary NaN's of complex inputs have to match.
>>> expected = torch.tensor(complex(float("NaN"), 0))
>>> actual = torch.tensor(complex(0, float("NaN")))
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
AssertionError: Tensors are not close!
>>> # If equal_nan="relaxed", however, then complex numbers are treated as NaN if any
>>> # of the real or imaginary component is NaN.
>>> torch.testing.assert_close(actual, expected, equal_nan="relaxed")
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overridden.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
AssertionError: Argh, the tensors are not close!
>>> # The error message can also be created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostic_info):
...     return (
...         f"Argh, we found {diagnostic_info.total_mismatches} mismatches! "
...         f"That is {diagnostic_info.mismatch_ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
AssertionError: Argh, we found 2 mismatches! That is 66.7%!