|
23 | 23 | from warnings import warn
|
24 | 24 |
|
25 | 25 | import pytest
|
26 |
| -from hypothesis import assume, given, note |
| 26 | +from hypothesis import given, note, settings |
27 | 27 | from hypothesis import strategies as st
|
28 | 28 |
|
29 | 29 | from array_api_tests.typing import Array, DataType
|
30 | 30 |
|
31 | 31 | from . import dtype_helpers as dh
|
32 | 32 | from . import hypothesis_helpers as hh
|
33 | 33 | from . import pytest_helpers as ph
|
34 |
| -from . import shape_helpers as sh |
35 | 34 | from . import xp, xps
|
36 | 35 | from .stubs import category_to_funcs
|
37 | 36 |
|
@@ -1210,143 +1209,57 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
|
1210 | 1209 | assert len(iop_params) != 0
|
1211 | 1210 |
|
1212 | 1211 |
|
1213 |
| -@pytest.mark.unvectorized |
1214 | 1212 | @pytest.mark.parametrize("func_name, func, case", unary_params)
|
1215 |
| -@given( |
1216 |
| - x=hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), |
1217 |
| - data=st.data(), |
1218 |
| -) |
1219 |
| -def test_unary(func_name, func, case, x, data): |
1220 |
| - set_idx = data.draw( |
1221 |
| - xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" |
| 1213 | +def test_unary(func_name, func, case): |
| 1214 | + in_value = case.cond_from_dtype(xp.float64).example() |
| 1215 | + x = xp.asarray(in_value, dtype=xp.float64) |
| 1216 | + out = func(x) |
| 1217 | + out_value = float(out) |
| 1218 | + assert case.check_result(in_value, out_value), ( |
| 1219 | + f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" |
1222 | 1220 | )
|
1223 |
| - set_value = data.draw(case.cond_from_dtype(x.dtype), label="set value") |
1224 |
| - x[set_idx] = set_value |
1225 |
| - note(f"{x=}") |
1226 |
| - |
1227 |
| - res = func(x) |
1228 |
| - |
1229 |
| - good_example = False |
1230 |
| - for idx in sh.ndindex(res.shape): |
1231 |
| - in_ = float(x[idx]) |
1232 |
| - if case.cond(in_): |
1233 |
| - good_example = True |
1234 |
| - out = float(res[idx]) |
1235 |
| - f_in = f"{sh.fmt_idx('x', idx)}={in_}" |
1236 |
| - f_out = f"{sh.fmt_idx('out', idx)}={out}" |
1237 |
| - assert case.check_result(in_, out), ( |
1238 |
| - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" |
1239 |
| - f"condition: {case.cond_expr}\n" |
1240 |
| - f"{f_in}" |
1241 |
| - ) |
1242 |
| - break |
1243 |
| - assume(good_example) |
1244 |
| - |
1245 |
| - |
1246 |
| -x1_strat, x2_strat = hh.two_mutual_arrays( |
1247 |
| - dtypes=dh.real_float_dtypes, |
1248 |
| - two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), |
1249 |
| -) |
1250 | 1221 |
|
1251 | 1222 |
|
1252 |
| -@pytest.mark.unvectorized |
1253 | 1223 | @pytest.mark.parametrize("func_name, func, case", binary_params)
|
1254 |
| -@given(x1=x1_strat, x2=x2_strat, data=st.data()) |
1255 |
| -def test_binary(func_name, func, case, x1, x2, data): |
1256 |
| - result_shape = sh.broadcast_shapes(x1.shape, x2.shape) |
1257 |
| - all_indices = list(sh.iter_indices(x1.shape, x2.shape, result_shape)) |
1258 |
| - |
1259 |
| - indices_strat = st.shared(st.sampled_from(all_indices)) |
1260 |
| - set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") |
1261 |
| - set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") |
1262 |
| - x1[set_x1_idx] = set_x1_value |
1263 |
| - note(f"{x1=}") |
1264 |
| - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") |
1265 |
| - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") |
1266 |
| - x2[set_x2_idx] = set_x2_value |
1267 |
| - note(f"{x2=}") |
1268 |
| - |
1269 |
| - res = func(x1, x2) |
1270 |
| - # sanity check |
1271 |
| - ph.assert_result_shape( |
1272 |
| - func_name, |
1273 |
| - in_shapes=[x1.shape, x2.shape], |
1274 |
| - out_shape=res.shape, |
1275 |
| - expected=result_shape, |
| 1224 | +@settings(max_examples=1) |
| 1225 | +@given(data=st.data()) |
| 1226 | +def test_binary(func_name, func, case, data): |
| 1227 | + # We don't use example() like in test_unary because the same internal shared |
| 1228 | + # strategies used in both x1's and x2's don't "sync" with example() draws. |
| 1229 | + x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") |
| 1230 | + x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") |
| 1231 | + x1 = xp.asarray(x1_value, dtype=xp.float64) |
| 1232 | + x2 = xp.asarray(x2_value, dtype=xp.float64) |
| 1233 | + |
| 1234 | + out = func(x1, x2) |
| 1235 | + out_value = float(out) |
| 1236 | + |
| 1237 | + assert case.check_result(x1_value, x2_value, out_value), ( |
| 1238 | + f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" |
| 1239 | + f"condition: {case}\n" |
| 1240 | + f"x1={x1_value}, x2={x2_value}" |
1276 | 1241 | )
|
1277 | 1242 |
|
1278 |
| - good_example = False |
1279 |
| - for l_idx, r_idx, o_idx in all_indices: |
1280 |
| - l = float(x1[l_idx]) |
1281 |
| - r = float(x2[r_idx]) |
1282 |
| - if case.cond(l, r): |
1283 |
| - good_example = True |
1284 |
| - o = float(res[o_idx]) |
1285 |
| - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" |
1286 |
| - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" |
1287 |
| - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" |
1288 |
| - assert case.check_result(l, r, o), ( |
1289 |
| - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" |
1290 |
| - f"condition: {case}\n" |
1291 |
| - f"{f_left}, {f_right}" |
1292 |
| - ) |
1293 |
| - break |
1294 |
| - assume(good_example) |
1295 | 1243 |
|
1296 | 1244 |
|
1297 |
| -@pytest.mark.unvectorized |
1298 | 1245 | @pytest.mark.parametrize("iop_name, iop, case", iop_params)
|
1299 |
| -@given( |
1300 |
| - oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes), |
1301 |
| - oneway_shapes=hh.oneway_broadcastable_shapes(), |
1302 |
| - data=st.data(), |
1303 |
| -) |
1304 |
| -def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): |
1305 |
| - x1 = data.draw( |
1306 |
| - hh.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape), |
1307 |
| - label="x1", |
| 1246 | +@settings(max_examples=1) |
| 1247 | +@given(data=st.data()) |
| 1248 | +def test_iop(iop_name, iop, case, data): |
| 1249 | + # See test_binary comment |
| 1250 | + x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") |
| 1251 | + x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") |
| 1252 | + x1 = xp.asarray(x1_value, dtype=xp.float64) |
| 1253 | + x2 = xp.asarray(x2_value, dtype=xp.float64) |
| 1254 | + |
| 1255 | + res = iop(x1, x2) |
| 1256 | + res_value = float(res) |
| 1257 | + |
| 1258 | + assert case.check_result(x1_value, x2_value, res_value), ( |
| 1259 | + f"x1={res}, but should be {case.result_expr} [{func_name}()]\n" |
| 1260 | + f"condition: {case}\n" |
| 1261 | + f"x1={x1_value}, x2={x2_value}" |
1308 | 1262 | )
|
1309 |
| - x2 = data.draw( |
1310 |
| - hh.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape), |
1311 |
| - label="x2", |
1312 |
| - ) |
1313 |
| - |
1314 |
| - all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape)) |
1315 |
| - |
1316 |
| - indices_strat = st.shared(st.sampled_from(all_indices)) |
1317 |
| - set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") |
1318 |
| - set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") |
1319 |
| - x1[set_x1_idx] = set_x1_value |
1320 |
| - note(f"{x1=}") |
1321 |
| - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") |
1322 |
| - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") |
1323 |
| - x2[set_x2_idx] = set_x2_value |
1324 |
| - note(f"{x2=}") |
1325 |
| - |
1326 |
| - res = xp.asarray(x1, copy=True) |
1327 |
| - res = iop(res, x2) |
1328 |
| - # sanity check |
1329 |
| - ph.assert_result_shape( |
1330 |
| - iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape |
1331 |
| - ) |
1332 |
| - |
1333 |
| - good_example = False |
1334 |
| - for l_idx, r_idx, o_idx in all_indices: |
1335 |
| - l = float(x1[l_idx]) |
1336 |
| - r = float(x2[r_idx]) |
1337 |
| - if case.cond(l, r): |
1338 |
| - good_example = True |
1339 |
| - o = float(res[o_idx]) |
1340 |
| - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" |
1341 |
| - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" |
1342 |
| - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" |
1343 |
| - assert case.check_result(l, r, o), ( |
1344 |
| - f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n" |
1345 |
| - f"condition: {case}\n" |
1346 |
| - f"{f_left}, {f_right}" |
1347 |
| - ) |
1348 |
| - break |
1349 |
| - assume(good_example) |
1350 | 1263 |
|
1351 | 1264 |
|
1352 | 1265 | @pytest.mark.parametrize(
|
|
0 commit comments