Skip to content

Commit e9abb15

Browse files
committed
Remove array helper imports which are covered by xp
1 parent 61c5aa0 commit e9abb15

File tree

1 file changed

+61
-68
lines changed

1 file changed

+61
-68
lines changed

array_api_tests/test_elementwise_functions.py

+61-68
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,31 @@
1414
1515
"""
1616

17-
from hypothesis import given, assume
18-
from hypothesis.strategies import composite, just
19-
2017
import math
2118

22-
from .hypothesis_helpers import (integer_dtype_objects,
23-
floating_dtype_objects,
24-
numeric_dtype_objects,
19+
from hypothesis import assume, given
20+
from hypothesis import strategies as st
21+
22+
from . import _array_module as xp
23+
from .array_helpers import (assert_exactly_equal, assert_integral,
24+
assert_same_sign, dtype_ranges, false, infinity,
25+
inrange, int_to_dtype, is_float_dtype,
26+
is_integer_dtype, isintegral, isnegative, ndindex,
27+
negative_mathematical_sign, one,
28+
positive_mathematical_sign, promote_dtypes, true,
29+
zero, π)
30+
from .hypothesis_helpers import (array_scalars, boolean_dtype_objects,
31+
boolean_dtypes, floating_dtype_objects,
32+
floating_dtypes, integer_dtype_objects,
2533
integer_or_boolean_dtype_objects,
26-
boolean_dtype_objects, floating_dtypes,
27-
numeric_dtypes, integer_or_boolean_dtypes,
28-
boolean_dtypes, mutually_promotable_dtypes,
29-
array_scalars, two_mutual_arrays, xps, shapes)
30-
from .array_helpers import (assert_exactly_equal, negative,
31-
positive_mathematical_sign,
32-
negative_mathematical_sign, logical_not,
33-
logical_or, logical_and, inrange, π, one, zero,
34-
infinity, isnegative, all as array_all, any as
35-
array_any, int_to_dtype,
36-
assert_integral, less_equal, isintegral, isfinite,
37-
ndindex, promote_dtypes, is_integer_dtype,
38-
is_float_dtype, not_equal, asarray,
39-
dtype_ranges, full, true, false, assert_same_sign,
40-
isnan, equal, less)
34+
integer_or_boolean_dtypes,
35+
mutually_promotable_dtypes,
36+
numeric_dtype_objects, numeric_dtypes, shapes,
37+
two_mutual_arrays, xps)
4138
# We might as well use this implementation rather than requiring
4239
# mod.broadcast_shapes(). See test_equal() and others.
4340
from .test_broadcasting import broadcast_shapes
4441

45-
from . import _array_module as xp
46-
4742
# integer_scalars = array_scalars(integer_dtypes)
4843
floating_scalars = array_scalars(floating_dtypes)
4944
numeric_scalars = array_scalars(numeric_dtypes)
@@ -57,11 +52,11 @@
5752
two_boolean_dtypes = mutually_promotable_dtypes(boolean_dtype_objects)
5853
two_any_dtypes = mutually_promotable_dtypes()
5954

60-
@composite
55+
@st.composite
6156
def two_array_scalars(draw, dtype1, dtype2):
6257
# two_dtypes should be a strategy that returns two dtypes (like
6358
# mutually_promotable_dtypes())
64-
return draw(array_scalars(just(dtype1))), draw(array_scalars(just(dtype2)))
59+
return draw(array_scalars(st.just(dtype1))), draw(array_scalars(st.just(dtype2)))
6560

6661
def sanity_check(x1, x2):
6762
try:
@@ -74,17 +69,17 @@ def test_abs(x):
7469
if is_integer_dtype(x.dtype):
7570
minval = dtype_ranges[x.dtype][0]
7671
if minval < 0:
77-
# abs of the smallest representable negative integer is not defined
78-
mask = not_equal(x, full(x.shape, minval, dtype=x.dtype))
72+
# abs of the smallest representable xp.negative integer is not defined
73+
mask = xp.not_equal(x, xp.full(x.shape, minval, dtype=x.dtype))
7974
x = x[mask]
8075
a = xp.abs(x)
81-
assert array_all(logical_not(negative_mathematical_sign(a))), "abs(x) did not have positive sign"
76+
assert xp.all(xp.logical_not(negative_mathematical_sign(a))), "abs(x) did not have positive sign"
8277
less_zero = negative_mathematical_sign(x)
83-
negx = negative(x)
78+
negx = xp.negative(x)
8479
# abs(x) = -x for x < 0
8580
assert_exactly_equal(a[less_zero], negx[less_zero])
8681
# abs(x) = x for x >= 0
87-
assert_exactly_equal(a[logical_not(less_zero)], x[logical_not(less_zero)])
82+
assert_exactly_equal(a[xp.logical_not(less_zero)], x[xp.logical_not(less_zero)])
8883

8984
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=shapes))
9085
def test_acos(x):
@@ -170,7 +165,7 @@ def test_atan2(x1_and_x2):
170165
# atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside
171166
# this domain are mapped to nan, which is already tested in the special
172167
# cases.
173-
assert_exactly_equal(logical_and(domainx1, domainx2), codomain)
168+
assert_exactly_equal(xp.logical_and(domainx1, domainx2), codomain)
174169
# From the spec:
175170
#
176171
# The mathematical signs of `x1_i` and `x2_i` determine the quadrant of
@@ -187,10 +182,10 @@ def test_atan2(x1_and_x2):
187182
negx2 = negative_mathematical_sign(x2)
188183
posa = positive_mathematical_sign(a)
189184
nega = negative_mathematical_sign(a)
190-
assert_exactly_equal(logical_or(logical_and(posx1, posx2),
191-
logical_and(posx1, negx2)), posa)
192-
assert_exactly_equal(logical_or(logical_and(negx1, posx2),
193-
logical_and(negx1, negx2)), nega)
185+
assert_exactly_equal(xp.logical_or(xp.logical_and(posx1, posx2),
186+
xp.logical_and(posx1, negx2)), posa)
187+
assert_exactly_equal(xp.logical_or(xp.logical_and(negx1, posx2),
188+
xp.logical_and(negx1, negx2)), nega)
194189

195190
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=shapes))
196191
def test_atanh(x):
@@ -232,8 +227,7 @@ def test_bitwise_left_shift(args):
232227
x1, x2 = args
233228
sanity_check(x1, x2)
234229
negative_x2 = isnegative(x2)
235-
if array_any(negative_x2):
236-
assume(False)
230+
assume(not xp.any(negative_x2))
237231
a = xp.bitwise_left_shift(x1, x2)
238232
# Compare against the Python << operator.
239233
# TODO: Generalize this properly for inputs that are arrays.
@@ -296,8 +290,7 @@ def test_bitwise_right_shift(args):
296290
x1, x2 = args
297291
sanity_check(x1, x2)
298292
negative_x2 = isnegative(x2)
299-
if array_any(negative_x2):
300-
assume(False)
293+
assume(not xp.any(negative_x2))
301294
a = xp.bitwise_right_shift(x1, x2)
302295
# Compare against the Python >> operator.
303296
# TODO: Generalize this properly for inputs that are arrays.
@@ -335,10 +328,10 @@ def test_bitwise_xor(args):
335328
def test_ceil(x):
336329
# This test is almost identical to test_floor()
337330
a = xp.ceil(x)
338-
finite = isfinite(x)
331+
finite = xp.isfinite(x)
339332
assert_integral(a[finite])
340-
assert array_all(less_equal(x[finite], a[finite]))
341-
assert array_all(less_equal(a[finite] - x[finite], one(x[finite].shape, x.dtype)))
333+
assert xp.all(xp.less_equal(x[finite], a[finite]))
334+
assert xp.all(xp.less_equal(a[finite] - x[finite], one(x[finite].shape, x.dtype)))
342335
integers = isintegral(x)
343336
assert_exactly_equal(a[integers], x[integers])
344337

@@ -380,7 +373,7 @@ def test_equal(x1_and_x2):
380373
x1, x2 = x1_and_x2
381374
sanity_check(x1, x2)
382375
a = xp.equal(x1, x2)
383-
# NOTE: assert_exactly_equal() itself uses equal(), so we must be careful
376+
# NOTE: assert_exactly_equal() itself uses xp.equal(), so we must be careful
384377
# not to use it here. Otherwise, the test would be circular and
385378
# meaningless. Instead, we implement this by iterating every element of
386379
# the arrays and comparing them. The logic here is also used for the tests
@@ -397,15 +390,15 @@ def test_equal(x1_and_x2):
397390
_x2 = xp.broadcast_to(x2, shape)
398391

399392
# Second, manually promote the dtypes. This is important. If the internal
400-
# type promotion in equal() is wrong, it will not be directly visible in
393+
# type promotion in xp.equal() is wrong, it will not be directly visible in
401394
# the output type, but it can lead to wrong answers. For example,
402-
# equal(array(1.0, dtype=float32), array(1.00000001, dtype=xp.float64)) will
403-
# be wrong if the xp.float64 is downcast to float32. # be wrong if the
395+
# xp.equal(array(1.0, dtype=xp.float32), array(1.00000001, dtype=xp.float64)) will
396+
# be wrong if the float64 is downcast to float32. # be wrong if the
404397
# xp.float64 is downcast to float32. See the comment on
405398
# test_elementwise_function_two_arg_bool_type_promotion() in
406-
# test_type_promotion.py. The type promotion for equal() is not *really*
399+
# test_type_promotion.py. The type promotion for xp.equal() is not *really*
407400
# tested in that file, because doing so requires doing the consistency
408-
# check we do here rather than just checking the result dtype.
401+
# check we do here rather than st.just checking the result dtype.
409402
promoted_dtype = promote_dtypes(x1.dtype, x2.dtype)
410403
_x1 = xp.asarray(_x1, dtype=promoted_dtype)
411404
_x2 = xp.asarray(_x2, dtype=promoted_dtype)
@@ -450,10 +443,10 @@ def test_expm1(x):
450443
def test_floor(x):
451444
# This test is almost identical to test_ceil
452445
a = xp.floor(x)
453-
finite = isfinite(x)
446+
finite = xp.isfinite(x)
454447
assert_integral(a[finite])
455-
assert array_all(less_equal(a[finite], x[finite]))
456-
assert array_all(less_equal(x[finite] - a[finite], one(x[finite].shape, x.dtype)))
448+
assert xp.all(xp.less_equal(a[finite], x[finite]))
449+
assert xp.all(xp.less_equal(x[finite] - a[finite], one(x[finite].shape, x.dtype)))
457450
integers = isintegral(x)
458451
assert_exactly_equal(a[integers], x[integers])
459452

@@ -467,8 +460,8 @@ def test_floor_divide(x1_and_x2):
467460
# we avoid passing it in entirely.
468461
assume(not xp.any(x1 == 0) and not xp.any(x2 == 0))
469462
div = xp.divide(
470-
asarray(x1, dtype=xp.float64),
471-
asarray(x2, dtype=xp.float64),
463+
xp.asarray(x1, dtype=xp.float64),
464+
xp.asarray(x2, dtype=xp.float64),
472465
)
473466
else:
474467
div = xp.divide(x1, x2)
@@ -477,7 +470,7 @@ def test_floor_divide(x1_and_x2):
477470

478471
# TODO: The spec doesn't clearly specify the behavior of floor_divide on
479472
# infinities. See https://github.com/data-apis/array-api/issues/199.
480-
finite = isfinite(div)
473+
finite = xp.isfinite(div)
481474
assert_integral(out[finite])
482475

483476
# TODO: Test the exact output for floor_divide.
@@ -548,9 +541,9 @@ def test_isfinite(x):
548541
TRUE = true(x.shape)
549542
if is_integer_dtype(x.dtype):
550543
assert_exactly_equal(a, TRUE)
551-
# Test that isfinite, isinf, and isnan are self-consistent.
552-
inf = logical_or(xp.isinf(x), xp.isnan(x))
553-
assert_exactly_equal(a, logical_not(inf))
544+
# Test that xp.isfinite, isinf, and xp.isnan are self-consistent.
545+
inf = xp.logical_or(xp.isinf(x), xp.isnan(x))
546+
assert_exactly_equal(a, xp.logical_not(inf))
554547

555548
# Test the exact value by comparing to the math version
556549
if is_float_dtype(x.dtype):
@@ -564,8 +557,8 @@ def test_isinf(x):
564557
FALSE = false(x.shape)
565558
if is_integer_dtype(x.dtype):
566559
assert_exactly_equal(a, FALSE)
567-
finite_or_nan = logical_or(xp.isfinite(x), xp.isnan(x))
568-
assert_exactly_equal(a, logical_not(finite_or_nan))
560+
finite_or_nan = xp.logical_or(xp.isfinite(x), xp.isnan(x))
561+
assert_exactly_equal(a, xp.logical_not(finite_or_nan))
569562

570563
# Test the exact value by comparing to the math version
571564
if is_float_dtype(x.dtype):
@@ -579,8 +572,8 @@ def test_isnan(x):
579572
FALSE = false(x.shape)
580573
if is_integer_dtype(x.dtype):
581574
assert_exactly_equal(a, FALSE)
582-
finite_or_inf = logical_or(xp.isfinite(x), xp.isinf(x))
583-
assert_exactly_equal(a, logical_not(finite_or_inf))
575+
finite_or_inf = xp.logical_or(xp.isfinite(x), xp.isinf(x))
576+
assert_exactly_equal(a, xp.logical_not(finite_or_inf))
584577

585578
# Test the exact value by comparing to the math version
586579
if is_float_dtype(x.dtype):
@@ -767,12 +760,12 @@ def test_negative(x):
767760
# Negation is an involution
768761
assert_exactly_equal(x, xp.negative(out))
769762

770-
mask = isfinite(x)
763+
mask = xp.isfinite(x)
771764
if is_integer_dtype(x.dtype):
772765
minval = dtype_ranges[x.dtype][0]
773766
if minval < 0:
774767
# negative of the smallest representable negative integer is not defined
775-
mask = not_equal(x, full(x.shape, minval, dtype=x.dtype))
768+
mask = xp.not_equal(x, xp.full(x.shape, minval, dtype=x.dtype))
776769

777770
# Additive inverse
778771
y = xp.add(x[mask], out[mask])
@@ -837,15 +830,15 @@ def test_remainder(x1_and_x2):
837830

838831
# out and x2 should have the same sign.
839832
# assert_same_sign returns False for nans
840-
not_nan = logical_not(logical_or(isnan(out), isnan(x2)))
833+
not_nan = xp.logical_not(xp.logical_or(xp.isnan(out), xp.isnan(x2)))
841834
assert_same_sign(out[not_nan], x2[not_nan])
842835

843836
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=shapes))
844837
def test_round(x):
845838
a = xp.round(x)
846839

847840
# Test that the result is integral
848-
finite = isfinite(x)
841+
finite = xp.isfinite(x)
849842
assert_integral(a[finite])
850843

851844
# round(x) should be the nearest integer to x. The case where there is a
@@ -858,8 +851,8 @@ def test_round(x):
858851
ceil = xp.ceil(x)
859852
over = xp.subtract(x, floor)
860853
under = xp.subtract(ceil, x)
861-
round_down = less(over, under)
862-
round_up = less(under, over)
854+
round_down = xp.less(over, under)
855+
round_up = xp.less(under, over)
863856
assert_exactly_equal(a[round_down], floor[round_down])
864857
assert_exactly_equal(a[round_up], ceil[round_up])
865858

@@ -910,7 +903,7 @@ def test_trunc(x):
910903
assert out.dtype == x.dtype, f"{x.dtype=!s} but {out.dtype=!s}"
911904
assert out.shape == x.shape, f"{x.shape} but {out.shape}"
912905
if x.dtype in integer_dtype_objects:
913-
assert array_all(equal(x, out)), f"{x=!s} but {out=!s}"
906+
assert xp.all(xp.equal(x, out)), f"{x=!s} but {out=!s}"
914907
else:
915908
finite_mask = xp.isfinite(out)
916909
for idx in ndindex(out.shape):

0 commit comments

Comments
 (0)