Skip to content

Commit 4b851ef

Browse files
committed
Remove some use of array_helpers.py which just mirrored xp funcs
The uses which weren't directly imported, i.e. used the alias `ah`
1 parent 7b5e3ab commit 4b851ef

File tree

3 files changed

+23
-25
lines changed

3 files changed

+23
-25
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Dict, Optional, Sequence, Tuple, Union
44

55
from . import _array_module as xp
6-
from . import array_helpers as ah
76
from . import dtype_helpers as dh
87
from . import shape_helpers as sh
98
from . import stubs
@@ -370,9 +369,9 @@ def assert_fill(
370369
"""
371370
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
372371
if math.isnan(fill_value):
373-
assert ah.all(ah.isnan(out)), msg
372+
assert xp.all(xp.isnan(out)), msg
374373
else:
375-
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
374+
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
376375

377376

378377
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):

array_api_tests/test_creation_functions.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from hypothesis import strategies as st
88

99
from . import _array_module as xp
10-
from . import array_helpers as ah
1110
from . import dtype_helpers as dh
1211
from . import hypothesis_helpers as hh
1312
from . import pytest_helpers as ph
@@ -181,12 +180,12 @@ def test_arange(dtype, data):
181180
if dh.is_int_dtype(_dtype):
182181
elements = list(r)
183182
assume(out.size == len(elements))
184-
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
183+
ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype))
185184
else:
186185
assume(out.size == size)
187186
if out.size > 0:
188-
assert ah.equal(
189-
out[0], ah.asarray(_start, dtype=out.dtype)
187+
assert xp.equal(
188+
out[0], xp.asarray(_start, dtype=out.dtype)
190189
), f"out[0]={out[0]}, but should be {_start} {f_func}"
191190

192191

@@ -421,8 +420,8 @@ def test_linspace(num, dtype, endpoint, data):
421420
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
422421
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
423422
# avoid overflow errors
424-
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
425-
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
423+
assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype)))
424+
assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype)))
426425

427426
kw = data.draw(
428427
hh.specified_kwargs(
@@ -440,20 +439,20 @@ def test_linspace(num, dtype, endpoint, data):
440439
ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
441440
f_func = f"[linspace({start}, {stop}, {num})]"
442441
if num > 0:
443-
assert ah.equal(
444-
out[0], ah.asarray(start, dtype=out.dtype)
442+
assert xp.equal(
443+
out[0], xp.asarray(start, dtype=out.dtype)
445444
), f"out[0]={out[0]}, but should be {start} {f_func}"
446445
if endpoint:
447446
if num > 1:
448-
assert ah.equal(
449-
out[-1], ah.asarray(stop, dtype=out.dtype)
447+
assert xp.equal(
448+
out[-1], xp.asarray(stop, dtype=out.dtype)
450449
), f"out[-1]={out[-1]}, but should be {stop} {f_func}"
451450
else:
452451
# linspace(..., num, endpoint=True) should return an array equivalent to
453452
# the first num elements when endpoint=False
454453
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
455454
expected = expected[:-1]
456-
ah.assert_exactly_equal(out, expected)
455+
ph.assert_array("linspace", out, expected)
457456

458457

459458
@given(dtype=xps.numeric_dtypes(), data=st.data())

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array:
409409

410410
def func(l: Array, r: Union[Scalar, Array]) -> Array:
411411
locals_ = {}
412-
locals_[left_sym] = ah.asarray(l, copy=True) # prevents mutating l
412+
locals_[left_sym] = xp.asarray(l, copy=True) # prevents mutating l
413413
locals_[right_sym] = r
414414
exec(expr, locals_)
415415
return locals_[left_sym]
@@ -659,7 +659,7 @@ def test_bitwise_left_shift(ctx, data):
659659
if ctx.right_is_scalar:
660660
assume(right >= 0)
661661
else:
662-
assume(not ah.any(ah.isnegative(right)))
662+
assume(not xp.any(ah.isnegative(right)))
663663

664664
res = ctx.func(left, right)
665665

@@ -718,7 +718,7 @@ def test_bitwise_right_shift(ctx, data):
718718
if ctx.right_is_scalar:
719719
assume(right >= 0)
720720
else:
721-
assume(not ah.any(ah.isnegative(right)))
721+
assume(not xp.any(ah.isnegative(right)))
722722

723723
res = ctx.func(left, right)
724724

@@ -851,13 +851,13 @@ def test_floor(x):
851851
@given(data=st.data())
852852
def test_floor_divide(ctx, data):
853853
left = data.draw(
854-
ctx.left_strat.filter(lambda x: not ah.any(x == 0)), label=ctx.left_sym
854+
ctx.left_strat.filter(lambda x: not xp.any(x == 0)), label=ctx.left_sym
855855
)
856856
right = data.draw(ctx.right_strat, label=ctx.right_sym)
857857
if ctx.right_is_scalar:
858858
assume(right != 0)
859859
else:
860-
assume(not ah.any(right == 0))
860+
assume(not xp.any(right == 0))
861861

862862
res = ctx.func(left, right)
863863

@@ -908,7 +908,7 @@ def test_greater_equal(ctx, data):
908908

909909
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
910910
def test_isfinite(x):
911-
out = ah.isfinite(x)
911+
out = xp.isfinite(x)
912912
ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool)
913913
ph.assert_shape("isfinite", out.shape, x.shape)
914914
unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool)
@@ -924,7 +924,7 @@ def test_isinf(x):
924924

925925
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
926926
def test_isnan(x):
927-
out = ah.isnan(x)
927+
out = xp.isnan(x)
928928
ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool)
929929
ph.assert_shape("isnan", out.shape, x.shape)
930930
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
@@ -1024,7 +1024,7 @@ def test_logaddexp(x1, x2):
10241024

10251025
@given(*hh.two_mutual_arrays([xp.bool]))
10261026
def test_logical_and(x1, x2):
1027-
out = ah.logical_and(x1, x2)
1027+
out = xp.logical_and(x1, x2)
10281028
ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype)
10291029
ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape)
10301030
binary_assert_against_refimpl(
@@ -1034,7 +1034,7 @@ def test_logical_and(x1, x2):
10341034

10351035
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes()))
10361036
def test_logical_not(x):
1037-
out = ah.logical_not(x)
1037+
out = xp.logical_not(x)
10381038
ph.assert_dtype("logical_not", x.dtype, out.dtype)
10391039
ph.assert_shape("logical_not", out.shape, x.shape)
10401040
unary_assert_against_refimpl(
@@ -1044,7 +1044,7 @@ def test_logical_not(x):
10441044

10451045
@given(*hh.two_mutual_arrays([xp.bool]))
10461046
def test_logical_or(x1, x2):
1047-
out = ah.logical_or(x1, x2)
1047+
out = xp.logical_or(x1, x2)
10481048
ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype)
10491049
ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape)
10501050
binary_assert_against_refimpl(
@@ -1157,7 +1157,7 @@ def test_remainder(ctx, data):
11571157
if ctx.right_is_scalar:
11581158
assume(right != 0)
11591159
else:
1160-
assume(not ah.any(right == 0))
1160+
assume(not xp.any(right == 0))
11611161

11621162
res = ctx.func(left, right)
11631163

0 commit comments

Comments
 (0)