"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools

import pytest
import torch

from array_api_compat import torch as xp

| 10 | + |
class TestResultType:
    """Tests for ``xp.result_type`` behavior left unspecified by the Array API
    standard: no arguments, single scalars/dtypes/arrays, scalar-vs-dtype
    promotion, and multi-argument calls mixing scalars, dtypes and tensors."""

    def test_empty(self):
        # Calling result_type with no arguments must raise.
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        # A single Python scalar (or other non-dtype object) is rejected.
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        # A single dtype is returned unchanged.
        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        # A single array resolves to its own dtype.
        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        # Only include here things "unspecified" in the spec.

        # scalar, tensor or tensor, tensor: must agree with torch.result_type.
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # dtype, scalar: compare against promotion with a 0-d tensor of that dtype.
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype: compare via 0-d tensors of each dtype.
        for x, y in [
            (xp.bool, xp.complex64)
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        # Fix: save and restore the global default dtype so this test cannot
        # leak state into other tests (the original set it and never reset it;
        # test_gh_273 below already follows the save/restore pattern).
        prev_default = torch.get_default_dtype()
        try:
            torch.set_default_dtype(torch.float32)

            # A float16 tensor caps the result despite surrounding Python scalars.
            args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
            assert xp.result_type(*args) == torch.float16

            # Complex scalar + float32 tensor -> complex64.
            args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
            assert xp.result_type(*args) == xp.complex64

            # Complex scalar + float64 dtype -> complex128.
            args = [1, 2, 3j, xp.float64, 4, 5, 6]
            assert xp.result_type(*args) == xp.complex128

            # Adding an int16 tensor and a bool scalar does not change the result.
            args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
            assert xp.result_type(*args) == xp.complex128

            # The result must not depend on argument order.
            i64 = xp.ones(1, dtype=xp.int64)
            f16 = xp.ones(1, dtype=xp.float16)
            for i in itertools.permutations([i64, f16, 1.0, 1.0]):
                assert xp.result_type(*i) == xp.float16, f"{i}"

            # All-scalar argument lists are rejected.
            with pytest.raises(ValueError):
                xp.result_type(1, 2, 3, 4)
        finally:
            torch.set_default_dtype(prev_default)


    @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
    @pytest.mark.parametrize("dtype_a",
            (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    @pytest.mark.parametrize("dtype_b",
            (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    def test_gh_273(self, default_dt, dtype_a, dtype_b):
        # Regression test for https://github.com/data-apis/array-api-compat/issues/273
        # result_type must be commutative in its array arguments.

        # Fix: capture the previous default *before* entering the try block, so
        # the finally clause can never hit a NameError on prev_default.
        prev_default = torch.get_default_dtype()
        try:
            default_dtype = getattr(torch, default_dt)
            torch.set_default_dtype(default_dtype)

            a = xp.asarray([2, 1], dtype=dtype_a)
            b = xp.asarray([1, -1], dtype=dtype_b)
            dtype_1 = xp.result_type(a, b, 1.0)
            dtype_2 = xp.result_type(b, a, 1.0)
            assert dtype_1 == dtype_2
        finally:
            torch.set_default_dtype(prev_default)
0 commit comments