|
1 | 1 | """
|
2 |
| -https://github.com/data-apis/array-api/blob/master/spec/API_specification/type_promotion.md |
| 2 | +https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html |
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 |
|
7 | 7 | from hypothesis import given, example
|
| 8 | +from hypothesis.strategies import from_type, data |
8 | 9 |
|
9 | 10 | from .hypothesis_helpers import shapes
|
10 | 11 | from .pytest_helpers import nargs
|
| 12 | +from .array_helpers import assert_exactly_equal |
11 | 13 |
|
12 | 14 | from .function_stubs import elementwise_functions
|
13 | 15 | from ._array_module import (ones, int8, int16, int32, int64, uint8,
|
|
94 | 96 | **float_promotion_table,
|
95 | 97 | }
|
96 | 98 |
|
| 99 | + |
| 100 | +binary_operators = { |
| 101 | + '__add__': '+', |
| 102 | + '__and__': '&', |
| 103 | + '__eq__': '==', |
| 104 | + '__floordiv__': '//', |
| 105 | + '__ge__': '>=', |
| 106 | + '__gt__': '>', |
| 107 | + '__le__': '<=', |
| 108 | + '__lshift__': '<<', |
| 109 | + '__lt__': '<', |
| 110 | + '__matmul__': '@', |
| 111 | + '__mod__': '%', |
| 112 | + '__mul__': '*', |
| 113 | + '__ne__': '!=', |
| 114 | + '__or__': '|', |
| 115 | + '__pow__': '**', |
| 116 | + '__rshift__': '>>', |
| 117 | + '__sub__': '-', |
| 118 | + '__truediv__': '/', |
| 119 | + '__xor__': '^', |
| 120 | +} |
| 121 | + |
| 122 | +unary_operators = { |
| 123 | + '__invert__': '~', |
| 124 | + '__neg__': '-', |
| 125 | + '__pos__': '+', |
| 126 | +} |
| 127 | + |
| 128 | +dtypes_to_scalar = { |
| 129 | + _array_module.bool: bool, |
| 130 | + _array_module.int8: int, |
| 131 | + _array_module.int16: int, |
| 132 | + _array_module.int32: int, |
| 133 | + _array_module.int64: int, |
| 134 | + _array_module.uint8: int, |
| 135 | + _array_module.uint16: int, |
| 136 | + _array_module.uint32: int, |
| 137 | + _array_module.uint64: int, |
| 138 | + _array_module.float32: float, |
| 139 | + _array_module.float64: float, |
| 140 | +} |
| 141 | + |
| 142 | +scalar_to_dtype = {s: [d for d, _s in dtypes_to_scalar.items() if _s == s] for |
| 143 | + s in dtypes_to_scalar.values()} |
| 144 | + |
97 | 145 | # TODO: Extend this to all functions (not just elementwise), and handle
|
98 | 146 | # functions that take more than 2 args
|
99 | 147 | @pytest.mark.parametrize('func_name', [i for i in
|
@@ -123,6 +171,43 @@ def test_promotion(func_name, shape, dtypes):
|
123 | 171 |
|
124 | 172 | assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape})"
|
125 | 173 |
|
| 174 | +@pytest.mark.parametrize('binary_op', sorted(set(binary_operators.values()) - {'@'})) |
| 175 | +@pytest.mark.parametrize('scalar_type,dtype', [(s, d) for s in scalar_to_dtype |
| 176 | + for d in scalar_to_dtype[s]]) |
| 177 | +@given(shape=shapes, scalars=data()) |
| 178 | +def test_operator_scalar_promotion(binary_op, scalar_type, dtype, shape, scalars): |
| 179 | + """ |
| 180 | + See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars |
| 181 | + """ |
| 182 | + if binary_op == '@': |
| 183 | + pytest.skip("matmul (@) is not supported for scalars") |
| 184 | + a = ones(shape, dtype=dtype) |
| 185 | + s = scalars.draw(from_type(scalar_type)) |
| 186 | + scalar_as_array = _array_module.full((), s, dtype=dtype) |
| 187 | + get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array) |
| 188 | + |
| 189 | + # As per the spec: |
| 190 | + |
| 191 | + # The expected behavior is then equivalent to: |
| 192 | + # |
| 193 | + # 1. Convert the scalar to a 0-D array with the same dtype as that of the |
| 194 | + # array used in the expression. |
| 195 | + # |
| 196 | + # 2. Execute the operation for `array <op> 0-D array` (or `0-D array <op> |
| 197 | + # array` if `scalar` was the left-hand argument). |
| 198 | + |
| 199 | + array_scalar = f'a {binary_op} s' |
| 200 | + array_scalar_expected = f'a {binary_op} scalar_as_array' |
| 201 | + res = eval(array_scalar, get_locals()) |
| 202 | + expected = eval(array_scalar_expected, get_locals()) |
| 203 | + assert_exactly_equal(res, expected) |
| 204 | + |
| 205 | + scalar_array = f's {binary_op} a' |
| 206 | + scalar_array_expected = f'scalar_as_array {binary_op} a' |
| 207 | + res = eval(scalar_array, get_locals()) |
| 208 | + expected = eval(scalar_array_expected, get_locals()) |
| 209 | + assert_exactly_equal(res, expected) |
| 210 | + |
126 | 211 | if __name__ == '__main__':
|
127 | 212 | for (i, j), p in promotion_table.items():
|
128 | 213 | print(f"({i}, {j}) -> {p}")
|
0 commit comments