Skip to content

Commit 3315e18

Browse files
committed
two_mutual_arrays() returns two strategies
1 parent 5b0c49a commit 3315e18

File tree

3 files changed

+64
-76
lines changed

3 files changed

+64
-76
lines changed

array_api_tests/hypothesis_helpers.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ def two_broadcastable_shapes(draw):
156156
sizes = integers(0, MAX_ARRAY_SIZE)
157157
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
158158

159-
# TODO: Generate general arrays here, rather than just scalars.
160159
numeric_arrays = xps.arrays(
161160
dtype=shared(xps.floating_dtypes(), key='dtypes'),
162161
shape=shared(xps.array_shapes(), key='shapes'),
@@ -267,13 +266,18 @@ def multiaxis_indices(draw, shapes):
267266
return tuple(res)
268267

269268

270-
@composite
271-
def two_mutual_arrays(draw, dtypes=dtype_objects):
272-
dtype1, dtype2 = draw(mutually_promotable_dtypes(dtypes))
273-
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
274-
x1 = draw(xps.arrays(dtype=dtype1, shape=shape1))
275-
x2 = draw(xps.arrays(dtype=dtype2, shape=shape2))
276-
return x1, x2
269+
def two_mutual_arrays(dtypes=dtype_objects):
270+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes))
271+
mutual_shapes = shared(two_mutually_broadcastable_shapes)
272+
arrays1 = xps.arrays(
273+
dtype=mutual_dtypes.map(lambda pair: pair[0]),
274+
shape=mutual_shapes.map(lambda pair: pair[0]),
275+
)
276+
arrays2 = xps.arrays(
277+
dtype=mutual_dtypes.map(lambda pair: pair[1]),
278+
shape=mutual_shapes.map(lambda pair: pair[1]),
279+
)
280+
return arrays1, arrays2
277281

278282

279283
@composite

array_api_tests/meta_tests/test_hypothesis_helpers.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from .._array_module import _UndefinedStub
88
from .. import array_helpers as ah
99
from .. import hypothesis_helpers as hh
10+
from ..test_broadcasting import broadcast_shapes
11+
from ..test_elementwise_functions import sanity_check
1012

1113
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects)
1214
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -44,10 +46,13 @@ def test_two_mutually_broadcastable_shapes(pair):
4446
def test_two_broadcastable_shapes(pair):
4547
for shape in pair:
4648
assert valid_shape(shape)
49+
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
4750

48-
from ..test_broadcasting import broadcast_shapes
4951

50-
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
52+
@given(*hh.two_mutual_arrays())
53+
def test_two_mutual_arrays(x1, x2):
54+
sanity_check(x1, x2)
55+
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
5156

5257

5358
def test_kwargs():

array_api_tests/test_elementwise_functions.py

+45-66
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def two_array_scalars(draw, dtype1, dtype2):
4141
# hh.mutually_promotable_dtypes())
4242
return draw(hh.array_scalars(st.just(dtype1))), draw(hh.array_scalars(st.just(dtype2)))
4343

44+
# TODO: refactor this into dtype_helpers.py, see https://github.com/data-apis/array-api-tests/pull/26
4445
def sanity_check(x1, x2):
4546
try:
4647
ah.promote_dtypes(x1.dtype, x2.dtype)
@@ -90,9 +91,8 @@ def test_acosh(x):
9091
# to nan, which is already tested in the special cases.
9192
ah.assert_exactly_equal(domain, codomain)
9293

93-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
94-
def test_add(x1_and_x2):
95-
x1, x2 = x1_and_x2
94+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
95+
def test_add(x1, x2):
9696
sanity_check(x1, x2)
9797
a = xp.add(x1, x2)
9898

@@ -133,9 +133,8 @@ def test_atan(x):
133133
# mapped to nan, which is already tested in the special cases.
134134
ah.assert_exactly_equal(domain, codomain)
135135

136-
@given(hh.two_mutual_arrays(hh.floating_dtype_objects))
137-
def test_atan2(x1_and_x2):
138-
x1, x2 = x1_and_x2
136+
@given(*hh.two_mutual_arrays(hh.floating_dtype_objects))
137+
def test_atan2(x1, x2):
139138
sanity_check(x1, x2)
140139
a = xp.atan2(x1, x2)
141140
INFINITY1 = ah.infinity(x1.shape, x1.dtype)
@@ -181,10 +180,9 @@ def test_atanh(x):
181180
# mapped to nan, which is already tested in the special cases.
182181
ah.assert_exactly_equal(domain, codomain)
183182

184-
@given(hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
185-
def test_bitwise_and(x1_and_x2):
183+
@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
184+
def test_bitwise_and(x1, x2):
186185
from .test_type_promotion import dtype_nbits, dtype_signed
187-
x1, x2 = x1_and_x2
188186
sanity_check(x1, x2)
189187
out = xp.bitwise_and(x1, x2)
190188

@@ -211,10 +209,9 @@ def test_bitwise_and(x1_and_x2):
211209
assert vals_and == res
212210

213211

214-
@given(hh.two_mutual_arrays(ah.integer_dtype_objects))
215-
def test_bitwise_left_shift(x1_and_x2):
212+
@given(*hh.two_mutual_arrays(ah.integer_dtype_objects))
213+
def test_bitwise_left_shift(x1, x2):
216214
from .test_type_promotion import dtype_nbits, dtype_signed
217-
x1, x2 = x1_and_x2
218215
sanity_check(x1, x2)
219216
assume(not ah.any(ah.isnegative(x2)))
220217
out = xp.bitwise_left_shift(x1, x2)
@@ -254,10 +251,9 @@ def test_bitwise_invert(x):
254251
val_invert = ah.int_to_dtype(val_invert, dtype_nbits(out.dtype), dtype_signed(out.dtype))
255252
assert val_invert == res
256253

257-
@given(hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
258-
def test_bitwise_or(x1_and_x2):
254+
@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
255+
def test_bitwise_or(x1, x2):
259256
from .test_type_promotion import dtype_nbits, dtype_signed
260-
x1, x2 = x1_and_x2
261257
sanity_check(x1, x2)
262258
out = xp.bitwise_or(x1, x2)
263259

@@ -283,10 +279,9 @@ def test_bitwise_or(x1_and_x2):
283279
vals_or = ah.int_to_dtype(vals_or, dtype_nbits(out.dtype), dtype_signed(out.dtype))
284280
assert vals_or == res
285281

286-
@given(hh.two_mutual_arrays(ah.integer_dtype_objects))
287-
def test_bitwise_right_shift(x1_and_x2):
282+
@given(*hh.two_mutual_arrays(ah.integer_dtype_objects))
283+
def test_bitwise_right_shift(x1, x2):
288284
from .test_type_promotion import dtype_nbits, dtype_signed
289-
x1, x2 = x1_and_x2
290285
sanity_check(x1, x2)
291286
assume(not ah.any(ah.isnegative(x2)))
292287
out = xp.bitwise_right_shift(x1, x2)
@@ -306,10 +301,9 @@ def test_bitwise_right_shift(x1_and_x2):
306301
vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits(out.dtype), dtype_signed(out.dtype))
307302
assert vals_shift == res
308303

309-
@given(hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
310-
def test_bitwise_xor(x1_and_x2):
304+
@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects))
305+
def test_bitwise_xor(x1, x2):
311306
from .test_type_promotion import dtype_nbits, dtype_signed
312-
x1, x2 = x1_and_x2
313307
sanity_check(x1, x2)
314308
out = xp.bitwise_xor(x1, x2)
315309

@@ -367,9 +361,8 @@ def test_cosh(x):
367361
# mapped to nan, which is already tested in the special cases.
368362
ah.assert_exactly_equal(domain, codomain)
369363

370-
@given(hh.two_mutual_arrays(hh.floating_dtype_objects))
371-
def test_divide(x1_and_x2):
372-
x1, x2 = x1_and_x2
364+
@given(*hh.two_mutual_arrays(hh.floating_dtype_objects))
365+
def test_divide(x1, x2):
373366
sanity_check(x1, x2)
374367
xp.divide(x1, x2)
375368
# There isn't much we can test here. The spec doesn't require any behavior
@@ -379,9 +372,8 @@ def test_divide(x1_and_x2):
379372
# have those sorts in general for this module.
380373

381374

382-
@given(hh.two_mutual_arrays())
383-
def test_equal(x1_and_x2):
384-
x1, x2 = x1_and_x2
375+
@given(*hh.two_mutual_arrays())
376+
def test_equal(x1, x2):
385377
sanity_check(x1, x2)
386378
a = ah.equal(x1, x2)
387379
# NOTE: ah.assert_exactly_equal() itself uses ah.equal(), so we must be careful
@@ -461,9 +453,8 @@ def test_floor(x):
461453
integers = ah.isintegral(x)
462454
ah.assert_exactly_equal(a[integers], x[integers])
463455

464-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
465-
def test_floor_divide(x1_and_x2):
466-
x1, x2 = x1_and_x2
456+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
457+
def test_floor_divide(x1, x2):
467458
sanity_check(x1, x2)
468459
if ah.is_integer_dtype(x1.dtype):
469460
# The spec does not specify the behavior for division by 0 for integer
@@ -486,9 +477,8 @@ def test_floor_divide(x1_and_x2):
486477

487478
# TODO: Test the exact output for floor_divide.
488479

489-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
490-
def test_greater(x1_and_x2):
491-
x1, x2 = x1_and_x2
480+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
481+
def test_greater(x1, x2):
492482
sanity_check(x1, x2)
493483
a = xp.greater(x1, x2)
494484

@@ -516,9 +506,8 @@ def test_greater(x1_and_x2):
516506
assert aidx.shape == x1idx.shape == x2idx.shape
517507
assert bool(aidx) == (scalar_func(x1idx) > scalar_func(x2idx))
518508

519-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
520-
def test_greater_equal(x1_and_x2):
521-
x1, x2 = x1_and_x2
509+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
510+
def test_greater_equal(x1, x2):
522511
sanity_check(x1, x2)
523512
a = xp.greater_equal(x1, x2)
524513

@@ -592,9 +581,8 @@ def test_isnan(x):
592581
s = float(x[idx])
593582
assert bool(a[idx]) == math.isnan(s)
594583

595-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
596-
def test_less(x1_and_x2):
597-
x1, x2 = x1_and_x2
584+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
585+
def test_less(x1, x2):
598586
sanity_check(x1, x2)
599587
a = ah.less(x1, x2)
600588

@@ -622,9 +610,8 @@ def test_less(x1_and_x2):
622610
assert aidx.shape == x1idx.shape == x2idx.shape
623611
assert bool(aidx) == (scalar_func(x1idx) < scalar_func(x2idx))
624612

625-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
626-
def test_less_equal(x1_and_x2):
627-
x1, x2 = x1_and_x2
613+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
614+
def test_less_equal(x1, x2):
628615
sanity_check(x1, x2)
629616
a = ah.less_equal(x1, x2)
630617

@@ -696,18 +683,16 @@ def test_log10(x):
696683
# mapped to nan, which is already tested in the special cases.
697684
ah.assert_exactly_equal(domain, codomain)
698685

699-
@given(hh.two_mutual_arrays(hh.floating_dtype_objects))
700-
def test_logaddexp(x1_and_x2):
701-
x1, x2 = x1_and_x2
686+
@given(*hh.two_mutual_arrays(hh.floating_dtype_objects))
687+
def test_logaddexp(x1, x2):
702688
sanity_check(x1, x2)
703689
xp.logaddexp(x1, x2)
704690
# The spec doesn't require any behavior for this function. We could test
705691
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
706692
# don't have tests for this sort of thing for any functions yet.
707693

708-
@given(hh.two_mutual_arrays([xp.bool]))
709-
def test_logical_and(x1_and_x2):
710-
x1, x2 = x1_and_x2
694+
@given(*hh.two_mutual_arrays([xp.bool]))
695+
def test_logical_and(x1, x2):
711696
sanity_check(x1, x2)
712697
a = ah.logical_and(x1, x2)
713698

@@ -726,9 +711,8 @@ def test_logical_not(x):
726711
for idx in ah.ndindex(x.shape):
727712
assert a[idx] == (not bool(x[idx]))
728713

729-
@given(hh.two_mutual_arrays([xp.bool]))
730-
def test_logical_or(x1_and_x2):
731-
x1, x2 = x1_and_x2
714+
@given(*hh.two_mutual_arrays([xp.bool]))
715+
def test_logical_or(x1, x2):
732716
sanity_check(x1, x2)
733717
a = ah.logical_or(x1, x2)
734718

@@ -740,9 +724,8 @@ def test_logical_or(x1_and_x2):
740724
for idx in ah.ndindex(shape):
741725
assert a[idx] == (bool(_x1[idx]) or bool(_x2[idx]))
742726

743-
@given(hh.two_mutual_arrays([xp.bool]))
744-
def test_logical_xor(x1_and_x2):
745-
x1, x2 = x1_and_x2
727+
@given(*hh.two_mutual_arrays([xp.bool]))
728+
def test_logical_xor(x1, x2):
746729
sanity_check(x1, x2)
747730
a = xp.logical_xor(x1, x2)
748731

@@ -754,9 +737,8 @@ def test_logical_xor(x1_and_x2):
754737
for idx in ah.ndindex(shape):
755738
assert a[idx] == (bool(_x1[idx]) ^ bool(_x2[idx]))
756739

757-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
758-
def test_multiply(x1_and_x2):
759-
x1, x2 = x1_and_x2
740+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
741+
def test_multiply(x1, x2):
760742
sanity_check(x1, x2)
761743
a = xp.multiply(x1, x2)
762744

@@ -784,9 +766,8 @@ def test_negative(x):
784766
ah.assert_exactly_equal(y, ZERO)
785767

786768

787-
@given(hh.two_mutual_arrays())
788-
def test_not_equal(x1_and_x2):
789-
x1, x2 = x1_and_x2
769+
@given(*hh.two_mutual_arrays())
770+
def test_not_equal(x1, x2):
790771
sanity_check(x1, x2)
791772
a = xp.not_equal(x1, x2)
792773

@@ -821,9 +802,8 @@ def test_positive(x):
821802
# Positive does nothing
822803
ah.assert_exactly_equal(out, x)
823804

824-
@given(hh.two_mutual_arrays(hh.floating_dtype_objects))
825-
def test_pow(x1_and_x2):
826-
x1, x2 = x1_and_x2
805+
@given(*hh.two_mutual_arrays(hh.floating_dtype_objects))
806+
def test_pow(x1, x2):
827807
sanity_check(x1, x2)
828808
xp.pow(x1, x2)
829809
# There isn't much we can test here. The spec doesn't require any behavior
@@ -832,9 +812,8 @@ def test_pow(x1_and_x2):
832812
# numbers. We could test that this does implement IEEE 754 pow, but we
833813
# don't yet have those sorts in general for this module.
834814

835-
@given(hh.two_mutual_arrays(hh.numeric_dtype_objects))
836-
def test_remainder(x1_and_x2):
837-
x1, x2 = x1_and_x2
815+
@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects))
816+
def test_remainder(x1, x2):
838817
assume(len(x1.shape) <= len(x2.shape)) # TODO: rework same sign testing below to remove this
839818
sanity_check(x1, x2)
840819
out = xp.remainder(x1, x2)

0 commit comments

Comments
 (0)