Skip to content

Commit 1d82bd9

Browse files
committed
Limit main special case tests to one example for now
1 parent 66fe25e commit 1d82bd9

File tree

1 file changed

+41
-128
lines changed

1 file changed

+41
-128
lines changed

array_api_tests/test_special_cases.py

+41-128
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,14 @@
2323
from warnings import warn
2424

2525
import pytest
26-
from hypothesis import assume, given, note
26+
from hypothesis import given, note, settings
2727
from hypothesis import strategies as st
2828

2929
from array_api_tests.typing import Array, DataType
3030

3131
from . import dtype_helpers as dh
3232
from . import hypothesis_helpers as hh
3333
from . import pytest_helpers as ph
34-
from . import shape_helpers as sh
3534
from . import xp, xps
3635
from .stubs import category_to_funcs
3736

@@ -1210,143 +1209,57 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
12101209
assert len(iop_params) != 0
12111210

12121211

1213-
@pytest.mark.unvectorized
12141212
@pytest.mark.parametrize("func_name, func, case", unary_params)
1215-
@given(
1216-
x=hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
1217-
data=st.data(),
1218-
)
1219-
def test_unary(func_name, func, case, x, data):
1220-
set_idx = data.draw(
1221-
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
1213+
def test_unary(func_name, func, case):
1214+
in_value = case.cond_from_dtype(xp.float64).example()
1215+
x = xp.asarray(in_value, dtype=xp.float64)
1216+
out = func(x)
1217+
out_value = float(out)
1218+
assert case.check_result(in_value, out_value), (
1219+
f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
12221220
)
1223-
set_value = data.draw(case.cond_from_dtype(x.dtype), label="set value")
1224-
x[set_idx] = set_value
1225-
note(f"{x=}")
1226-
1227-
res = func(x)
1228-
1229-
good_example = False
1230-
for idx in sh.ndindex(res.shape):
1231-
in_ = float(x[idx])
1232-
if case.cond(in_):
1233-
good_example = True
1234-
out = float(res[idx])
1235-
f_in = f"{sh.fmt_idx('x', idx)}={in_}"
1236-
f_out = f"{sh.fmt_idx('out', idx)}={out}"
1237-
assert case.check_result(in_, out), (
1238-
f"{f_out}, but should be {case.result_expr} [{func_name}()]\n"
1239-
f"condition: {case.cond_expr}\n"
1240-
f"{f_in}"
1241-
)
1242-
break
1243-
assume(good_example)
1244-
1245-
1246-
x1_strat, x2_strat = hh.two_mutual_arrays(
1247-
dtypes=dh.real_float_dtypes,
1248-
two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1),
1249-
)
12501221

12511222

1252-
@pytest.mark.unvectorized
12531223
@pytest.mark.parametrize("func_name, func, case", binary_params)
1254-
@given(x1=x1_strat, x2=x2_strat, data=st.data())
1255-
def test_binary(func_name, func, case, x1, x2, data):
1256-
result_shape = sh.broadcast_shapes(x1.shape, x2.shape)
1257-
all_indices = list(sh.iter_indices(x1.shape, x2.shape, result_shape))
1258-
1259-
indices_strat = st.shared(st.sampled_from(all_indices))
1260-
set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx")
1261-
set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value")
1262-
x1[set_x1_idx] = set_x1_value
1263-
note(f"{x1=}")
1264-
set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
1265-
set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
1266-
x2[set_x2_idx] = set_x2_value
1267-
note(f"{x2=}")
1268-
1269-
res = func(x1, x2)
1270-
# sanity check
1271-
ph.assert_result_shape(
1272-
func_name,
1273-
in_shapes=[x1.shape, x2.shape],
1274-
out_shape=res.shape,
1275-
expected=result_shape,
1224+
@settings(max_examples=1)
1225+
@given(data=st.data())
1226+
def test_binary(func_name, func, case, data):
1227+
# We don't use example() like in test_unary because the same internal shared
1228+
# strategies used in both x1's and x2's don't "sync" with example() draws.
1229+
x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
1230+
x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
1231+
x1 = xp.asarray(x1_value, dtype=xp.float64)
1232+
x2 = xp.asarray(x2_value, dtype=xp.float64)
1233+
1234+
out = func(x1, x2)
1235+
out_value = float(out)
1236+
1237+
assert case.check_result(x1_value, x2_value, out_value), (
1238+
f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
1239+
f"condition: {case}\n"
1240+
f"x1={x1_value}, x2={x2_value}"
12761241
)
12771242

1278-
good_example = False
1279-
for l_idx, r_idx, o_idx in all_indices:
1280-
l = float(x1[l_idx])
1281-
r = float(x2[r_idx])
1282-
if case.cond(l, r):
1283-
good_example = True
1284-
o = float(res[o_idx])
1285-
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
1286-
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
1287-
f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
1288-
assert case.check_result(l, r, o), (
1289-
f"{f_out}, but should be {case.result_expr} [{func_name}()]\n"
1290-
f"condition: {case}\n"
1291-
f"{f_left}, {f_right}"
1292-
)
1293-
break
1294-
assume(good_example)
12951243

12961244

1297-
@pytest.mark.unvectorized
12981245
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
1299-
@given(
1300-
oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes),
1301-
oneway_shapes=hh.oneway_broadcastable_shapes(),
1302-
data=st.data(),
1303-
)
1304-
def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
1305-
x1 = data.draw(
1306-
hh.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape),
1307-
label="x1",
1246+
@settings(max_examples=1)
1247+
@given(data=st.data())
1248+
def test_iop(iop_name, iop, case, data):
1249+
# See test_binary comment
1250+
x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
1251+
x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
1252+
x1 = xp.asarray(x1_value, dtype=xp.float64)
1253+
x2 = xp.asarray(x2_value, dtype=xp.float64)
1254+
1255+
res = iop(x1, x2)
1256+
res_value = float(res)
1257+
1258+
assert case.check_result(x1_value, x2_value, res_value), (
1259+
f"x1={res}, but should be {case.result_expr} [{func_name}()]\n"
1260+
f"condition: {case}\n"
1261+
f"x1={x1_value}, x2={x2_value}"
13081262
)
1309-
x2 = data.draw(
1310-
hh.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape),
1311-
label="x2",
1312-
)
1313-
1314-
all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape))
1315-
1316-
indices_strat = st.shared(st.sampled_from(all_indices))
1317-
set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx")
1318-
set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value")
1319-
x1[set_x1_idx] = set_x1_value
1320-
note(f"{x1=}")
1321-
set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
1322-
set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
1323-
x2[set_x2_idx] = set_x2_value
1324-
note(f"{x2=}")
1325-
1326-
res = xp.asarray(x1, copy=True)
1327-
res = iop(res, x2)
1328-
# sanity check
1329-
ph.assert_result_shape(
1330-
iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape
1331-
)
1332-
1333-
good_example = False
1334-
for l_idx, r_idx, o_idx in all_indices:
1335-
l = float(x1[l_idx])
1336-
r = float(x2[r_idx])
1337-
if case.cond(l, r):
1338-
good_example = True
1339-
o = float(res[o_idx])
1340-
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
1341-
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
1342-
f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
1343-
assert case.check_result(l, r, o), (
1344-
f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n"
1345-
f"condition: {case}\n"
1346-
f"{f_left}, {f_right}"
1347-
)
1348-
break
1349-
assume(good_example)
13501263

13511264

13521265
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)