Skip to content

Commit 82e6312

Browse files
committed
Rudimentary complex testing for unary elwise functions
1 parent 0c2c0f7 commit 82e6312

File tree

4 files changed

+80
-30
lines changed

4 files changed

+80
-30
lines changed

array_api_tests/dtype_helpers.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"uint_dtypes",
1616
"all_int_dtypes",
1717
"float_dtypes",
18+
"real_dtypes",
1819
"numeric_dtypes",
1920
"all_dtypes",
2021
"dtype_to_name",
@@ -30,6 +31,7 @@
3031
"promotion_table",
3132
"dtype_nbits",
3233
"dtype_signed",
34+
"dtype_components",
3335
"func_in_dtypes",
3436
"func_returns_bool",
3537
"binary_op_to_symbol",
@@ -86,14 +88,19 @@ def __repr__(self):
8688
_uint_names = ("uint8", "uint16", "uint32", "uint64")
8789
_int_names = ("int8", "int16", "int32", "int64")
8890
_float_names = ("float32", "float64")
89-
_dtype_names = ("bool",) + _uint_names + _int_names + _float_names
91+
_real_names = _uint_names + _int_names + _float_names
92+
_complex_names = ("complex64", "complex128")
93+
_numeric_names = _real_names + _complex_names
94+
_dtype_names = ("bool",) + _numeric_names
9095

9196

9297
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
9398
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
9499
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
95100
all_int_dtypes = uint_dtypes + int_dtypes
96-
numeric_dtypes = all_int_dtypes + float_dtypes
101+
real_dtypes = all_int_dtypes + float_dtypes
102+
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
103+
numeric_dtypes = real_dtypes + complex_dtypes
97104
all_dtypes = (xp.bool,) + numeric_dtypes
98105
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
99106

@@ -129,6 +136,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
129136
return int
130137
elif is_float_dtype(dtype):
131138
return float
139+
elif dtype in complex_dtypes:
140+
return complex
132141
else:
133142
return bool
134143

@@ -157,7 +166,8 @@ class MinMax(NamedTuple):
157166
[(d, 8) for d in [xp.int8, xp.uint8]]
158167
+ [(d, 16) for d in [xp.int16, xp.uint16]]
159168
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
160-
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
169+
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]]
170+
+ [(xp.complex128, 128)]
161171
)
162172

163173

@@ -166,6 +176,11 @@ class MinMax(NamedTuple):
166176
)
167177

168178

179+
dtype_components = EqualityMapping(
180+
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
181+
)
182+
183+
169184
if isinstance(xp.asarray, _UndefinedStub):
170185
default_int = xp.int32
171186
default_float = xp.float32
@@ -226,6 +241,11 @@ class MinMax(NamedTuple):
226241
((xp.float32, xp.float32), xp.float32),
227242
((xp.float32, xp.float64), xp.float64),
228243
((xp.float64, xp.float64), xp.float64),
244+
# complex
245+
((xp.complex64, xp.complex64), xp.complex64),
246+
((xp.complex64, xp.complex128), xp.complex128),
247+
((xp.complex128, xp.complex128), xp.complex128),
248+
229249
]
230250
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
231251
_promotion_table = list(set(_numeric_promotions))

array_api_tests/hypothesis_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
shared_dtypes = shared(dtypes, key="dtype")
4747
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
4848

49-
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
49+
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes]
5050
_sorted_dtypes = [d for category in _dtype_categories for d in category]
5151

5252
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):

array_api_tests/pytest_helpers.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,21 @@ def assert_fill(
374374
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
375375

376376

377+
def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
378+
if xp.isnan(at_expected):
379+
assert xp.isnan(at_out), msg
380+
elif at_expected == 0.0 or at_expected == -0.0:
381+
scalar_at_expected = float(at_expected)
382+
scalar_at_out = float(at_out)
383+
if is_pos_zero(scalar_at_expected):
384+
assert is_pos_zero(scalar_at_out), msg
385+
else:
386+
assert is_neg_zero(scalar_at_expected) # sanity check
387+
assert is_neg_zero(scalar_at_out), msg
388+
else:
389+
assert at_out == at_expected, msg
390+
391+
377392
def assert_array_elements(
378393
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
379394
):
@@ -392,26 +407,26 @@ def assert_array_elements(
392407
dh.result_type(out.dtype, expected.dtype) # sanity check
393408
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
394409
f_func = f"[{func_name}({fmt_kw(kw)})]"
395-
if dh.is_float_dtype(out.dtype):
410+
if out.dtype in dh.float_dtypes:
411+
for idx in sh.ndindex(out.shape):
412+
at_out = out[idx]
413+
at_expected = expected[idx]
414+
msg = (
415+
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
416+
f"{f_func}"
417+
)
418+
_assert_float_element(at_out, at_expected, msg)
419+
elif out.dtype in dh.complex_dtypes:
420+
assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
396421
for idx in sh.ndindex(out.shape):
397422
at_out = out[idx]
398423
at_expected = expected[idx]
399424
msg = (
400425
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
401426
f"{f_func}"
402427
)
403-
if xp.isnan(at_expected):
404-
assert xp.isnan(at_out), msg
405-
elif at_expected == 0.0 or at_expected == -0.0:
406-
scalar_at_expected = float(at_expected)
407-
scalar_at_out = float(at_out)
408-
if is_pos_zero(scalar_at_expected):
409-
assert is_pos_zero(scalar_at_out), msg
410-
else:
411-
assert is_neg_zero(scalar_at_expected) # sanity check
412-
assert is_neg_zero(scalar_at_out), msg
413-
else:
414-
assert at_out == at_expected, msg
428+
_assert_float_element(at_out.real, at_expected.real, msg)
429+
_assert_float_element(at_out.imag, at_expected.imag, msg)
415430
else:
416431
assert xp.all(
417432
out == expected

array_api_tests/test_operators_and_elementwise_functions.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def default_filter(s: Scalar) -> bool:
103103
"""
104104
if isinstance(s, int): # note bools are ints
105105
return True
106+
elif isinstance(s, complex):
107+
return default_filter(s.real) and default_filter(s.imag)
106108
else:
107109
return math.isfinite(s) and s != 0
108110

@@ -247,7 +249,12 @@ def unary_assert_against_refimpl(
247249
in_stype = dh.get_scalar_type(in_.dtype)
248250
if res_stype is None:
249251
res_stype = in_stype
250-
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
252+
if res.dtype == xp.bool:
253+
m, M = (None, None)
254+
if res.dtype in dh.complex_dtypes:
255+
m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]]
256+
else:
257+
m, M = dh.dtype_ranges[res.dtype]
251258
for idx in sh.ndindex(in_.shape):
252259
scalar_i = in_stype(in_[idx])
253260
if not filter_(scalar_i):
@@ -257,9 +264,13 @@ def unary_assert_against_refimpl(
257264
except Exception:
258265
continue
259266
if res.dtype != xp.bool:
260-
assert m is not None and M is not None # for mypy
261-
if expected <= m or expected >= M:
262-
continue
267+
if res.dtype in dh.complex_dtypes:
268+
for component in [expected.real, expected.imag]:
269+
if component <= m or expected >= M:
270+
continue
271+
else:
272+
if expected <= m or expected >= M:
273+
continue
263274
scalar_o = res_stype(res[idx])
264275
f_i = sh.fmt_idx("x", idx)
265276
f_o = sh.fmt_idx("out", idx)
@@ -418,8 +429,11 @@ def __repr__(self):
418429

419430

420431
def make_unary_params(
421-
elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType]
432+
elwise_func_name: str, dtypes: Sequence[DataType]
422433
) -> List[Param[UnaryParamContext]]:
434+
if hh.FILTER_UNDEFINED_DTYPES:
435+
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
436+
dtypes_strat = st.sampled_from(dtypes)
423437
strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes())
424438
func_ctx = UnaryParamContext(
425439
func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat
@@ -633,7 +647,7 @@ def binary_param_assert_against_refimpl(
633647
)
634648

635649

636-
@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes()))
650+
@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
637651
@given(data=st.data())
638652
def test_abs(ctx, data):
639653
x = data.draw(ctx.strat, label="x")
@@ -643,7 +657,10 @@ def test_abs(ctx, data):
643657

644658
out = ctx.func(x)
645659

646-
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
660+
if x.dtype in dh.complex_dtypes:
661+
assert out.dtype == dh.complex_components[x.dtype]
662+
else:
663+
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
647664
ph.assert_shape(ctx.func_name, out.shape, x.shape)
648665
unary_assert_against_refimpl(
649666
ctx.func_name,
@@ -783,7 +800,7 @@ def test_bitwise_left_shift(ctx, data):
783800

784801

785802
@pytest.mark.parametrize(
786-
"ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes())
803+
"ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
787804
)
788805
@given(data=st.data())
789806
def test_bitwise_invert(ctx, data):
@@ -1187,9 +1204,7 @@ def test_multiply(ctx, data):
11871204

11881205

11891206
# TODO: clarify if uints are acceptable, adjust accordingly
1190-
@pytest.mark.parametrize(
1191-
"ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes())
1192-
)
1207+
@pytest.mark.parametrize("ctx", make_unary_params("negative", dh.numeric_dtypes))
11931208
@given(data=st.data())
11941209
def test_negative(ctx, data):
11951210
x = data.draw(ctx.strat, label="x")
@@ -1226,7 +1241,7 @@ def test_not_equal(ctx, data):
12261241
)
12271242

12281243

1229-
@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes()))
1244+
@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
12301245
@given(data=st.data())
12311246
def test_positive(ctx, data):
12321247
x = data.draw(ctx.strat, label="x")
@@ -1317,7 +1332,7 @@ def test_square(x):
13171332
ph.assert_dtype("square", x.dtype, out.dtype)
13181333
ph.assert_shape("square", out.shape, x.shape)
13191334
unary_assert_against_refimpl(
1320-
"square", x, out, lambda s: s ** 2, expr_template="{}²={}"
1335+
"square", x, out, lambda s: s**2, expr_template="{}²={}"
13211336
)
13221337

13231338

0 commit comments

Comments
 (0)