Skip to content

Commit 613ae94

Browse files
committed
Add test for array <binary op> scalar
This is based on what the spec currently says, but there are some issues with it (see data-apis/array-api#98).
1 parent 6557c79 commit 613ae94

File tree

1 file changed

+86
-1
lines changed

1 file changed

+86
-1
lines changed

array_api_tests/test_type_promotion.py

+86-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""
2-
https://github.com/data-apis/array-api/blob/master/spec/API_specification/type_promotion.md
2+
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33
"""
44

55
import pytest
66

77
from hypothesis import given, example
8+
from hypothesis.strategies import from_type, data
89

910
from .hypothesis_helpers import shapes
1011
from .pytest_helpers import nargs
12+
from .array_helpers import assert_exactly_equal
1113

1214
from .function_stubs import elementwise_functions
1315
from ._array_module import (ones, int8, int16, int32, int64, uint8,
@@ -94,6 +96,52 @@
9496
**float_promotion_table,
9597
}
9698

99+
100+
binary_operators = {
101+
'__add__': '+',
102+
'__and__': '&',
103+
'__eq__': '==',
104+
'__floordiv__': '//',
105+
'__ge__': '>=',
106+
'__gt__': '>',
107+
'__le__': '<=',
108+
'__lshift__': '<<',
109+
'__lt__': '<',
110+
'__matmul__': '@',
111+
'__mod__': '%',
112+
'__mul__': '*',
113+
'__ne__': '!=',
114+
'__or__': '|',
115+
'__pow__': '**',
116+
'__rshift__': '>>',
117+
'__sub__': '-',
118+
'__truediv__': '/',
119+
'__xor__': '^',
120+
}
121+
122+
unary_operators = {
123+
'__invert__': '~',
124+
'__neg__': '-',
125+
'__pos__': '+',
126+
}
127+
128+
dtypes_to_scalar = {
129+
_array_module.bool: bool,
130+
_array_module.int8: int,
131+
_array_module.int16: int,
132+
_array_module.int32: int,
133+
_array_module.int64: int,
134+
_array_module.uint8: int,
135+
_array_module.uint16: int,
136+
_array_module.uint32: int,
137+
_array_module.uint64: int,
138+
_array_module.float32: float,
139+
_array_module.float64: float,
140+
}
141+
142+
scalar_to_dtype = {s: [d for d, _s in dtypes_to_scalar.items() if _s == s] for
143+
s in dtypes_to_scalar.values()}
144+
97145
# TODO: Extend this to all functions (not just elementwise), and handle
98146
# functions that take more than 2 args
99147
@pytest.mark.parametrize('func_name', [i for i in
@@ -123,6 +171,43 @@ def test_promotion(func_name, shape, dtypes):
123171

124172
assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape})"
125173

174+
@pytest.mark.parametrize('binary_op', sorted(set(binary_operators.values()) - {'@'}))
175+
@pytest.mark.parametrize('scalar_type,dtype', [(s, d) for s in scalar_to_dtype
176+
for d in scalar_to_dtype[s]])
177+
@given(shape=shapes, scalars=data())
178+
def test_operator_scalar_promotion(binary_op, scalar_type, dtype, shape, scalars):
179+
"""
180+
See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
181+
"""
182+
if binary_op == '@':
183+
pytest.skip("matmul (@) is not supported for scalars")
184+
a = ones(shape, dtype=dtype)
185+
s = scalars.draw(from_type(scalar_type))
186+
scalar_as_array = _array_module.full((), s, dtype=dtype)
187+
get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array)
188+
189+
# As per the spec:
190+
191+
# The expected behavior is then equivalent to:
192+
#
193+
# 1. Convert the scalar to a 0-D array with the same dtype as that of the
194+
# array used in the expression.
195+
#
196+
# 2. Execute the operation for `array <op> 0-D array` (or `0-D array <op>
197+
# array` if `scalar` was the left-hand argument).
198+
199+
array_scalar = f'a {binary_op} s'
200+
array_scalar_expected = f'a {binary_op} scalar_as_array'
201+
res = eval(array_scalar, get_locals())
202+
expected = eval(array_scalar_expected, get_locals())
203+
assert_exactly_equal(res, expected)
204+
205+
scalar_array = f's {binary_op} a'
206+
scalar_array_expected = f'scalar_as_array {binary_op} a'
207+
res = eval(scalar_array, get_locals())
208+
expected = eval(scalar_array_expected, get_locals())
209+
assert_exactly_equal(res, expected)
210+
126211
if __name__ == '__main__':
127212
for (i, j), p in promotion_table.items():
128213
print(f"({i}, {j}) -> {p}")

0 commit comments

Comments
 (0)