Skip to content

Commit cd941c9

Browse files
authored
Merge pull request #81 from honno/more-tests
Make some type promotion tests standalone, dtype info tests
2 parents a313c3d + 3ed56a9 commit cd941c9

File tree

4 files changed

+104
-56
lines changed

4 files changed

+104
-56
lines changed

array_api_tests/test_creation_functions.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import math
2-
import pytest
32
from itertools import count
43
from typing import Iterator, NamedTuple, Union
54

5+
import pytest
66
from hypothesis import assume, given, note
77
from hypothesis import strategies as st
88

@@ -190,10 +190,7 @@ def test_arange(dtype, data):
190190
), f"out[0]={out[0]}, but should be {_start} {f_func}"
191191

192192

193-
@given(
194-
shape=hh.shapes(min_side=1),
195-
data=st.data(),
196-
)
193+
@given(shape=hh.shapes(min_side=1), data=st.data())
197194
def test_asarray_scalars(shape, data):
198195
kw = data.draw(
199196
hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw"
@@ -482,6 +479,29 @@ def test_linspace(num, dtype, endpoint, data):
482479
ah.assert_exactly_equal(out, expected)
483480

484481

482+
@given(
483+
# The number and size of generated arrays is arbitrarily limited to prevent
484+
# meshgrid() running out of memory.
485+
dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes),
486+
data=st.data(),
487+
)
488+
def test_meshgrid(dtypes, data):
489+
arrays = []
490+
shapes = data.draw(
491+
hh.mutually_broadcastable_shapes(
492+
len(dtypes), min_dims=1, max_dims=1, max_side=5
493+
),
494+
label="shapes",
495+
)
496+
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
497+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
498+
arrays.append(x)
499+
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
500+
out = xp.meshgrid(*arrays)
501+
for i, x in enumerate(out):
502+
ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype")
503+
504+
485505
def make_one(dtype: DataType) -> Scalar:
486506
if dtype is None or dh.is_float_dtype(dtype):
487507
return 1.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
from hypothesis import given
3+
4+
from . import _array_module as xp
5+
from . import dtype_helpers as dh
6+
from . import hypothesis_helpers as hh
7+
from . import pytest_helpers as ph
8+
from .typing import DataType
9+
10+
11+
def make_dtype_id(dtype: DataType) -> str:
12+
return dh.dtype_to_name[dtype]
13+
14+
15+
@pytest.mark.parametrize("dtype", dh.float_dtypes, ids=make_dtype_id)
16+
def test_finfo(dtype):
17+
out = xp.finfo(dtype)
18+
f_func = f"[finfo({dh.dtype_to_name[dtype]})]"
19+
for attr, stype in [
20+
("bits", int),
21+
("eps", float),
22+
("max", float),
23+
("min", float),
24+
("smallest_normal", float),
25+
]:
26+
assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}"
27+
value = getattr(out, attr)
28+
assert isinstance(
29+
value, stype
30+
), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}"
31+
# TODO: test values
32+
33+
34+
@pytest.mark.parametrize("dtype", dh.all_int_dtypes, ids=make_dtype_id)
35+
def test_iinfo(dtype):
36+
out = xp.iinfo(dtype)
37+
f_func = f"[iinfo({dh.dtype_to_name[dtype]})]"
38+
for attr in ["bits", "max", "min"]:
39+
assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}"
40+
value = getattr(out, attr)
41+
assert isinstance(
42+
value, int
43+
), f"type(out.{attr})={type(value)!r}, but should be int {f_func}"
44+
# TODO: test values
45+
46+
47+
@given(hh.mutually_promotable_dtypes(None))
48+
def test_result_type(dtypes):
49+
out = xp.result_type(*dtypes)
50+
ph.assert_dtype("result_type", dtypes, out, repr_name="out")

array_api_tests/test_linalg.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
19-
shared, sampled_from)
19+
shared, sampled_from, data, just)
2020

2121
from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity
2222
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
@@ -33,6 +33,7 @@
3333
from .algos import broadcast_shapes
3434

3535
from . import _array_module
36+
from . import _array_module as xp
3637
from ._array_module import linalg
3738

3839
pytestmark = pytest.mark.ci
@@ -556,13 +557,20 @@ def test_svdvals(x):
556557

557558

558559
@given(
559-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
560-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
561-
kw=kwargs(axes=todo)
560+
dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes),
561+
shape=shapes(),
562+
data=data(),
562563
)
563-
def test_tensordot(x1, x2, kw):
564-
# res = _array_module.tensordot(x1, x2, **kw)
565-
pass
564+
def test_tensordot(dtypes, shape, data):
565+
# TODO: vary shapes, vary contracted axes, test different axes arguments
566+
x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1")
567+
x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2")
568+
569+
out = xp.tensordot(x1, x2, axes=len(shape))
570+
571+
ph.assert_dtype("tensordot", dtypes, out.dtype)
572+
# TODO: assert shape and elements
573+
566574

567575
@pytest.mark.xp_extension('linalg')
568576
@given(
@@ -605,13 +613,21 @@ def true_trace(x_stack):
605613

606614

607615
@given(
608-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
609-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
610-
kw=kwargs(axis=todo)
616+
dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes),
617+
shape=shapes(),
618+
data=data(),
611619
)
612-
def test_vecdot(x1, x2, kw):
613-
# res = _array_module.vecdot(x1, x2, **kw)
614-
pass
620+
def test_vecdot(dtypes, shape, data):
621+
# TODO: vary shapes, test different axis arguments
622+
x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1")
623+
x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2")
624+
kw = data.draw(kwargs(axis=just(-1)))
625+
626+
out = xp.vecdot(x1, x2, **kw)
627+
628+
ph.assert_dtype("vecdot", dtypes, out.dtype)
629+
# TODO: assert shape and elements
630+
615631

616632
@pytest.mark.xp_extension('linalg')
617633
@given(

array_api_tests/test_type_promotion.py

-38
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,6 @@
1616
from .function_stubs import elementwise_functions
1717
from .typing import DataType, Param, ScalarType
1818

19-
# TODO: move tests not covering elementwise funcs/ops into standalone tests
20-
# result_type, meshgrid, tensordor, vecdot
21-
22-
23-
@given(hh.mutually_promotable_dtypes(None))
24-
def test_result_type(dtypes):
25-
out = xp.result_type(*dtypes)
26-
ph.assert_dtype("result_type", dtypes, out, repr_name="out")
27-
28-
2919
bitwise_shift_funcs = [
3020
"bitwise_left_shift",
3121
"bitwise_right_shift",
@@ -133,37 +123,9 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
133123
promotion_params.append(p)
134124

135125

136-
@pytest.mark.parametrize("in_dtypes, out_dtype", promotion_params)
137-
@given(shapes=hh.mutually_broadcastable_shapes(3), data=st.data())
138-
def test_where(in_dtypes, out_dtype, shapes, data):
139-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
140-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
141-
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label="condition")
142-
out = xp.where(cond, x1, x2)
143-
ph.assert_dtype("where", in_dtypes, out.dtype, out_dtype)
144-
145-
146126
numeric_promotion_params = promotion_params[1:]
147127

148128

149-
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
150-
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=2), data=st.data())
151-
def test_tensordot(in_dtypes, out_dtype, shapes, data):
152-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
153-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
154-
out = xp.tensordot(x1, x2)
155-
ph.assert_dtype("tensordot", in_dtypes, out.dtype, out_dtype)
156-
157-
158-
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
159-
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data())
160-
def test_vecdot(in_dtypes, out_dtype, shapes, data):
161-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
162-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
163-
out = xp.vecdot(x1, x2)
164-
ph.assert_dtype("vecdot", in_dtypes, out.dtype, out_dtype)
165-
166-
167129
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
168130
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
169131
for op, symbol in op_to_symbol.items():

0 commit comments

Comments
 (0)