@@ -41,6 +41,7 @@ def two_array_scalars(draw, dtype1, dtype2):
41
41
# hh.mutually_promotable_dtypes())
42
42
return draw (hh .array_scalars (st .just (dtype1 ))), draw (hh .array_scalars (st .just (dtype2 )))
43
43
44
+ # TODO: refactor this into dtype_helpers.py, see https://github.com/data-apis/array-api-tests/pull/26
44
45
def sanity_check (x1 , x2 ):
45
46
try :
46
47
ah .promote_dtypes (x1 .dtype , x2 .dtype )
@@ -90,9 +91,8 @@ def test_acosh(x):
90
91
# to nan, which is already tested in the special cases.
91
92
ah .assert_exactly_equal (domain , codomain )
92
93
93
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
94
- def test_add (x1_and_x2 ):
95
- x1 , x2 = x1_and_x2
94
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
95
+ def test_add (x1 , x2 ):
96
96
sanity_check (x1 , x2 )
97
97
a = xp .add (x1 , x2 )
98
98
@@ -133,9 +133,8 @@ def test_atan(x):
133
133
# mapped to nan, which is already tested in the special cases.
134
134
ah .assert_exactly_equal (domain , codomain )
135
135
136
- @given (hh .two_mutual_arrays (hh .floating_dtype_objects ))
137
- def test_atan2 (x1_and_x2 ):
138
- x1 , x2 = x1_and_x2
136
+ @given (* hh .two_mutual_arrays (hh .floating_dtype_objects ))
137
+ def test_atan2 (x1 , x2 ):
139
138
sanity_check (x1 , x2 )
140
139
a = xp .atan2 (x1 , x2 )
141
140
INFINITY1 = ah .infinity (x1 .shape , x1 .dtype )
@@ -181,10 +180,9 @@ def test_atanh(x):
181
180
# mapped to nan, which is already tested in the special cases.
182
181
ah .assert_exactly_equal (domain , codomain )
183
182
184
- @given (hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
185
- def test_bitwise_and (x1_and_x2 ):
183
+ @given (* hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
184
+ def test_bitwise_and (x1 , x2 ):
186
185
from .test_type_promotion import dtype_nbits , dtype_signed
187
- x1 , x2 = x1_and_x2
188
186
sanity_check (x1 , x2 )
189
187
out = xp .bitwise_and (x1 , x2 )
190
188
@@ -211,10 +209,9 @@ def test_bitwise_and(x1_and_x2):
211
209
assert vals_and == res
212
210
213
211
214
- @given (hh .two_mutual_arrays (ah .integer_dtype_objects ))
215
- def test_bitwise_left_shift (x1_and_x2 ):
212
+ @given (* hh .two_mutual_arrays (ah .integer_dtype_objects ))
213
+ def test_bitwise_left_shift (x1 , x2 ):
216
214
from .test_type_promotion import dtype_nbits , dtype_signed
217
- x1 , x2 = x1_and_x2
218
215
sanity_check (x1 , x2 )
219
216
assume (not ah .any (ah .isnegative (x2 )))
220
217
out = xp .bitwise_left_shift (x1 , x2 )
@@ -254,10 +251,9 @@ def test_bitwise_invert(x):
254
251
val_invert = ah .int_to_dtype (val_invert , dtype_nbits (out .dtype ), dtype_signed (out .dtype ))
255
252
assert val_invert == res
256
253
257
- @given (hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
258
- def test_bitwise_or (x1_and_x2 ):
254
+ @given (* hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
255
+ def test_bitwise_or (x1 , x2 ):
259
256
from .test_type_promotion import dtype_nbits , dtype_signed
260
- x1 , x2 = x1_and_x2
261
257
sanity_check (x1 , x2 )
262
258
out = xp .bitwise_or (x1 , x2 )
263
259
@@ -283,10 +279,9 @@ def test_bitwise_or(x1_and_x2):
283
279
vals_or = ah .int_to_dtype (vals_or , dtype_nbits (out .dtype ), dtype_signed (out .dtype ))
284
280
assert vals_or == res
285
281
286
- @given (hh .two_mutual_arrays (ah .integer_dtype_objects ))
287
- def test_bitwise_right_shift (x1_and_x2 ):
282
+ @given (* hh .two_mutual_arrays (ah .integer_dtype_objects ))
283
+ def test_bitwise_right_shift (x1 , x2 ):
288
284
from .test_type_promotion import dtype_nbits , dtype_signed
289
- x1 , x2 = x1_and_x2
290
285
sanity_check (x1 , x2 )
291
286
assume (not ah .any (ah .isnegative (x2 )))
292
287
out = xp .bitwise_right_shift (x1 , x2 )
@@ -306,10 +301,9 @@ def test_bitwise_right_shift(x1_and_x2):
306
301
vals_shift = ah .int_to_dtype (vals_shift , dtype_nbits (out .dtype ), dtype_signed (out .dtype ))
307
302
assert vals_shift == res
308
303
309
- @given (hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
310
- def test_bitwise_xor (x1_and_x2 ):
304
+ @given (* hh .two_mutual_arrays (ah .integer_or_boolean_dtype_objects ))
305
+ def test_bitwise_xor (x1 , x2 ):
311
306
from .test_type_promotion import dtype_nbits , dtype_signed
312
- x1 , x2 = x1_and_x2
313
307
sanity_check (x1 , x2 )
314
308
out = xp .bitwise_xor (x1 , x2 )
315
309
@@ -367,9 +361,8 @@ def test_cosh(x):
367
361
# mapped to nan, which is already tested in the special cases.
368
362
ah .assert_exactly_equal (domain , codomain )
369
363
370
- @given (hh .two_mutual_arrays (hh .floating_dtype_objects ))
371
- def test_divide (x1_and_x2 ):
372
- x1 , x2 = x1_and_x2
364
+ @given (* hh .two_mutual_arrays (hh .floating_dtype_objects ))
365
+ def test_divide (x1 , x2 ):
373
366
sanity_check (x1 , x2 )
374
367
xp .divide (x1 , x2 )
375
368
# There isn't much we can test here. The spec doesn't require any behavior
@@ -379,9 +372,8 @@ def test_divide(x1_and_x2):
379
372
# have those sorts in general for this module.
380
373
381
374
382
- @given (hh .two_mutual_arrays ())
383
- def test_equal (x1_and_x2 ):
384
- x1 , x2 = x1_and_x2
375
+ @given (* hh .two_mutual_arrays ())
376
+ def test_equal (x1 , x2 ):
385
377
sanity_check (x1 , x2 )
386
378
a = ah .equal (x1 , x2 )
387
379
# NOTE: ah.assert_exactly_equal() itself uses ah.equal(), so we must be careful
@@ -461,9 +453,8 @@ def test_floor(x):
461
453
integers = ah .isintegral (x )
462
454
ah .assert_exactly_equal (a [integers ], x [integers ])
463
455
464
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
465
- def test_floor_divide (x1_and_x2 ):
466
- x1 , x2 = x1_and_x2
456
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
457
+ def test_floor_divide (x1 , x2 ):
467
458
sanity_check (x1 , x2 )
468
459
if ah .is_integer_dtype (x1 .dtype ):
469
460
# The spec does not specify the behavior for division by 0 for integer
@@ -486,9 +477,8 @@ def test_floor_divide(x1_and_x2):
486
477
487
478
# TODO: Test the exact output for floor_divide.
488
479
489
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
490
- def test_greater (x1_and_x2 ):
491
- x1 , x2 = x1_and_x2
480
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
481
+ def test_greater (x1 , x2 ):
492
482
sanity_check (x1 , x2 )
493
483
a = xp .greater (x1 , x2 )
494
484
@@ -516,9 +506,8 @@ def test_greater(x1_and_x2):
516
506
assert aidx .shape == x1idx .shape == x2idx .shape
517
507
assert bool (aidx ) == (scalar_func (x1idx ) > scalar_func (x2idx ))
518
508
519
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
520
- def test_greater_equal (x1_and_x2 ):
521
- x1 , x2 = x1_and_x2
509
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
510
+ def test_greater_equal (x1 , x2 ):
522
511
sanity_check (x1 , x2 )
523
512
a = xp .greater_equal (x1 , x2 )
524
513
@@ -592,9 +581,8 @@ def test_isnan(x):
592
581
s = float (x [idx ])
593
582
assert bool (a [idx ]) == math .isnan (s )
594
583
595
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
596
- def test_less (x1_and_x2 ):
597
- x1 , x2 = x1_and_x2
584
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
585
+ def test_less (x1 , x2 ):
598
586
sanity_check (x1 , x2 )
599
587
a = ah .less (x1 , x2 )
600
588
@@ -622,9 +610,8 @@ def test_less(x1_and_x2):
622
610
assert aidx .shape == x1idx .shape == x2idx .shape
623
611
assert bool (aidx ) == (scalar_func (x1idx ) < scalar_func (x2idx ))
624
612
625
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
626
- def test_less_equal (x1_and_x2 ):
627
- x1 , x2 = x1_and_x2
613
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
614
+ def test_less_equal (x1 , x2 ):
628
615
sanity_check (x1 , x2 )
629
616
a = ah .less_equal (x1 , x2 )
630
617
@@ -696,18 +683,16 @@ def test_log10(x):
696
683
# mapped to nan, which is already tested in the special cases.
697
684
ah .assert_exactly_equal (domain , codomain )
698
685
699
- @given (hh .two_mutual_arrays (hh .floating_dtype_objects ))
700
- def test_logaddexp (x1_and_x2 ):
701
- x1 , x2 = x1_and_x2
686
+ @given (* hh .two_mutual_arrays (hh .floating_dtype_objects ))
687
+ def test_logaddexp (x1 , x2 ):
702
688
sanity_check (x1 , x2 )
703
689
xp .logaddexp (x1 , x2 )
704
690
# The spec doesn't require any behavior for this function. We could test
705
691
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
706
692
# don't have tests for this sort of thing for any functions yet.
707
693
708
- @given (hh .two_mutual_arrays ([xp .bool ]))
709
- def test_logical_and (x1_and_x2 ):
710
- x1 , x2 = x1_and_x2
694
+ @given (* hh .two_mutual_arrays ([xp .bool ]))
695
+ def test_logical_and (x1 , x2 ):
711
696
sanity_check (x1 , x2 )
712
697
a = ah .logical_and (x1 , x2 )
713
698
@@ -726,9 +711,8 @@ def test_logical_not(x):
726
711
for idx in ah .ndindex (x .shape ):
727
712
assert a [idx ] == (not bool (x [idx ]))
728
713
729
- @given (hh .two_mutual_arrays ([xp .bool ]))
730
- def test_logical_or (x1_and_x2 ):
731
- x1 , x2 = x1_and_x2
714
+ @given (* hh .two_mutual_arrays ([xp .bool ]))
715
+ def test_logical_or (x1 , x2 ):
732
716
sanity_check (x1 , x2 )
733
717
a = ah .logical_or (x1 , x2 )
734
718
@@ -740,9 +724,8 @@ def test_logical_or(x1_and_x2):
740
724
for idx in ah .ndindex (shape ):
741
725
assert a [idx ] == (bool (_x1 [idx ]) or bool (_x2 [idx ]))
742
726
743
- @given (hh .two_mutual_arrays ([xp .bool ]))
744
- def test_logical_xor (x1_and_x2 ):
745
- x1 , x2 = x1_and_x2
727
+ @given (* hh .two_mutual_arrays ([xp .bool ]))
728
+ def test_logical_xor (x1 , x2 ):
746
729
sanity_check (x1 , x2 )
747
730
a = xp .logical_xor (x1 , x2 )
748
731
@@ -754,9 +737,8 @@ def test_logical_xor(x1_and_x2):
754
737
for idx in ah .ndindex (shape ):
755
738
assert a [idx ] == (bool (_x1 [idx ]) ^ bool (_x2 [idx ]))
756
739
757
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
758
- def test_multiply (x1_and_x2 ):
759
- x1 , x2 = x1_and_x2
740
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
741
+ def test_multiply (x1 , x2 ):
760
742
sanity_check (x1 , x2 )
761
743
a = xp .multiply (x1 , x2 )
762
744
@@ -784,9 +766,8 @@ def test_negative(x):
784
766
ah .assert_exactly_equal (y , ZERO )
785
767
786
768
787
- @given (hh .two_mutual_arrays ())
788
- def test_not_equal (x1_and_x2 ):
789
- x1 , x2 = x1_and_x2
769
+ @given (* hh .two_mutual_arrays ())
770
+ def test_not_equal (x1 , x2 ):
790
771
sanity_check (x1 , x2 )
791
772
a = xp .not_equal (x1 , x2 )
792
773
@@ -821,9 +802,8 @@ def test_positive(x):
821
802
# Positive does nothing
822
803
ah .assert_exactly_equal (out , x )
823
804
824
- @given (hh .two_mutual_arrays (hh .floating_dtype_objects ))
825
- def test_pow (x1_and_x2 ):
826
- x1 , x2 = x1_and_x2
805
+ @given (* hh .two_mutual_arrays (hh .floating_dtype_objects ))
806
+ def test_pow (x1 , x2 ):
827
807
sanity_check (x1 , x2 )
828
808
xp .pow (x1 , x2 )
829
809
# There isn't much we can test here. The spec doesn't require any behavior
@@ -832,9 +812,8 @@ def test_pow(x1_and_x2):
832
812
# numbers. We could test that this does implement IEEE 754 pow, but we
833
813
# don't yet have those sorts in general for this module.
834
814
835
- @given (hh .two_mutual_arrays (hh .numeric_dtype_objects ))
836
- def test_remainder (x1_and_x2 ):
837
- x1 , x2 = x1_and_x2
815
+ @given (* hh .two_mutual_arrays (hh .numeric_dtype_objects ))
816
+ def test_remainder (x1 , x2 ):
838
817
assume (len (x1 .shape ) <= len (x2 .shape )) # TODO: rework same sign testing below to remove this
839
818
sanity_check (x1 , x2 )
840
819
out = xp .remainder (x1 , x2 )
0 commit comments