Skip to content

Commit c22efdf

Browse files
committed
Move broadcast_shapes() to shape_helpers.py
1 parent 84fcf98 commit c22efdf

12 files changed

+70
-26
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from ._array_module import _UndefinedStub
1717
from ._array_module import bool as bool_dtype
1818
from ._array_module import broadcast_to, eye, float32, float64, full
19-
from .algos import broadcast_shapes
2019
from .function_stubs import elementwise_functions
2120
from .pytest_helpers import nargs
2221
from .typing import Array, DataType, Shape
@@ -243,7 +242,7 @@ def two_broadcastable_shapes(draw):
243242
broadcast to shape1.
244243
"""
245244
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
246-
assume(broadcast_shapes(shape1, shape2) == shape1)
245+
assume(sh.broadcast_shapes(shape1, shape2) == shape1)
247246
return (shape1, shape2)
248247

249248
sizes = integers(0, MAX_ARRAY_SIZE)

array_api_tests/meta/test_broadcasting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from ..algos import BroadcastError, _broadcast_shapes
7+
from .. import shape_helpers as sh
88

99

1010
@pytest.mark.parametrize(
@@ -19,7 +19,7 @@
1919
],
2020
)
2121
def test_broadcast_shapes(shape1, shape2, expected):
22-
assert _broadcast_shapes(shape1, shape2) == expected
22+
assert sh._broadcast_shapes(shape1, shape2) == expected
2323

2424

2525
@pytest.mark.parametrize(
@@ -31,5 +31,5 @@ def test_broadcast_shapes(shape1, shape2, expected):
3131
],
3232
)
3333
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
34-
with pytest.raises(BroadcastError):
35-
_broadcast_shapes(shape1, shape2)
34+
with pytest.raises(sh.BroadcastError):
35+
sh._broadcast_shapes(shape1, shape2)

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from .. import array_helpers as ah
99
from .. import dtype_helpers as dh
1010
from .. import hypothesis_helpers as hh
11+
from .. import shape_helpers as sh
1112
from .. import xps
1213
from .._array_module import _UndefinedStub
13-
from ..algos import broadcast_shapes
1414

1515
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1616
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -62,7 +62,7 @@ def test_two_mutually_broadcastable_shapes(pair):
6262
def test_two_broadcastable_shapes(pair):
6363
for shape in pair:
6464
assert valid_shape(shape)
65-
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
65+
assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0]
6666

6767

6868
@given(*hh.two_mutual_arrays())

array_api_tests/pytest_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from . import array_helpers as ah
77
from . import dtype_helpers as dh
88
from . import function_stubs
9-
from .algos import broadcast_shapes
9+
from . import shape_helpers as sh
1010
from .typing import Array, DataType, Scalar, ScalarType, Shape
1111

1212
__all__ = [
@@ -159,7 +159,7 @@ def assert_result_shape(
159159
**kw,
160160
):
161161
if expected is None:
162-
expected = broadcast_shapes(*in_shapes)
162+
expected = sh.broadcast_shapes(*in_shapes)
163163
f_in_shapes = " . ".join(str(s) for s in in_shapes)
164164
f_sig = f" {f_in_shapes} "
165165
if kw:

array_api_tests/shape_helpers.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .typing import AtomicIndex, Index, Scalar, Shape
99

1010
__all__ = [
11+
"broadcast_shapes",
1112
"normalise_axis",
1213
"ndindex",
1314
"axis_ndindex",
@@ -17,6 +18,54 @@
1718
]
1819

1920

21+
class BroadcastError(ValueError):
22+
"""Shapes do not broadcast with eachother"""
23+
24+
25+
def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
26+
"""Broadcasts `shape1` and `shape2`"""
27+
N1 = len(shape1)
28+
N2 = len(shape2)
29+
N = max(N1, N2)
30+
shape = [None for _ in range(N)]
31+
i = N - 1
32+
while i >= 0:
33+
n1 = N1 - N + i
34+
if N1 - N + i >= 0:
35+
d1 = shape1[n1]
36+
else:
37+
d1 = 1
38+
n2 = N2 - N + i
39+
if N2 - N + i >= 0:
40+
d2 = shape2[n2]
41+
else:
42+
d2 = 1
43+
44+
if d1 == 1:
45+
shape[i] = d2
46+
elif d2 == 1:
47+
shape[i] = d1
48+
elif d1 == d2:
49+
shape[i] = d1
50+
else:
51+
raise BroadcastError()
52+
53+
i = i - 1
54+
55+
return tuple(shape)
56+
57+
58+
def broadcast_shapes(*shapes: Shape):
59+
if len(shapes) == 0:
60+
raise ValueError("shapes=[] must be non-empty")
61+
elif len(shapes) == 1:
62+
return shapes[0]
63+
result = _broadcast_shapes(shapes[0], shapes[1])
64+
for i in range(2, len(shapes)):
65+
result = _broadcast_shapes(result, shapes[i])
66+
return result
67+
68+
2069
def normalise_axis(
2170
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
2271
) -> Tuple[int, ...]:

array_api_tests/test_data_type_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from . import dtype_helpers as dh
1010
from . import hypothesis_helpers as hh
1111
from . import pytest_helpers as ph
12+
from . import shape_helpers as sh
1213
from . import xps
13-
from .algos import broadcast_shapes
1414
from .typing import DataType
1515

1616
pytestmark = pytest.mark.ci
@@ -70,7 +70,7 @@ def test_broadcast_arrays(shapes, data):
7070

7171
out = xp.broadcast_arrays(*arrays)
7272

73-
out_shape = broadcast_shapes(*shapes)
73+
out_shape = sh.broadcast_shapes(*shapes)
7474
for i, x in enumerate(arrays):
7575
ph.assert_dtype(
7676
"broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype"
@@ -90,7 +90,7 @@ def test_broadcast_to(x, data):
9090
shape = data.draw(
9191
hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
9292
.map(lambda S: S[0])
93-
.filter(lambda s: broadcast_shapes(x.shape, s) == s),
93+
.filter(lambda s: sh.broadcast_shapes(x.shape, s) == s),
9494
label="shape",
9595
)
9696

array_api_tests/test_linalg.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
from . import pytest_helpers as ph
3131
from . import shape_helpers as sh
3232

33-
from .algos import broadcast_shapes
34-
3533
from . import _array_module
3634
from . import _array_module as xp
3735
from ._array_module import linalg
@@ -56,7 +54,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw):
5654
if res is None:
5755
res = f(*args, **kw)
5856

59-
shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape
57+
shape = args[0].shape if len(args) == 1 else sh.broadcast_shapes(*[x.shape
6058
for x in args])
6159
for _idx in sh.ndindex(shape[:-2]):
6260
idx = _idx + (slice(None),)*dims
@@ -297,7 +295,7 @@ def test_matmul(x1, x2):
297295
assert res.shape == x1.shape[:-1]
298296
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
299297
else:
300-
stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
298+
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
301299
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
302300
_test_stacks(_array_module.matmul, x1, x2, res=res)
303301

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from . import pytest_helpers as ph
2626
from . import shape_helpers as sh
2727
from . import xps
28-
from .algos import broadcast_shapes
2928
from .typing import Array, DataType, Param, Scalar, Shape
3029

3130
pytestmark = pytest.mark.ci
@@ -1223,7 +1222,7 @@ def test_logical_and(x1, x2):
12231222
out = ah.logical_and(x1, x2)
12241223
ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype)
12251224
# See the comments in test_equal
1226-
shape = broadcast_shapes(x1.shape, x2.shape)
1225+
shape = sh.broadcast_shapes(x1.shape, x2.shape)
12271226
ph.assert_shape("logical_and", out.shape, shape)
12281227
_x1 = xp.broadcast_to(x1, shape)
12291228
_x2 = xp.broadcast_to(x2, shape)
@@ -1245,7 +1244,7 @@ def test_logical_or(x1, x2):
12451244
out = ah.logical_or(x1, x2)
12461245
ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype)
12471246
# See the comments in test_equal
1248-
shape = broadcast_shapes(x1.shape, x2.shape)
1247+
shape = sh.broadcast_shapes(x1.shape, x2.shape)
12491248
ph.assert_shape("logical_or", out.shape, shape)
12501249
_x1 = xp.broadcast_to(x1, shape)
12511250
_x2 = xp.broadcast_to(x2, shape)
@@ -1258,7 +1257,7 @@ def test_logical_xor(x1, x2):
12581257
out = xp.logical_xor(x1, x2)
12591258
ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype)
12601259
# See the comments in test_equal
1261-
shape = broadcast_shapes(x1.shape, x2.shape)
1260+
shape = sh.broadcast_shapes(x1.shape, x2.shape)
12621261
ph.assert_shape("logical_xor", out.shape, shape)
12631262
_x1 = xp.broadcast_to(x1, shape)
12641263
_x2 = xp.broadcast_to(x2, shape)

array_api_tests/test_searching_functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from . import pytest_helpers as ph
99
from . import shape_helpers as sh
1010
from . import xps
11-
from .algos import broadcast_shapes
1211

1312
pytestmark = pytest.mark.ci
1413

@@ -134,7 +133,7 @@ def test_where(shapes, dtypes, data):
134133

135134
out = xp.where(cond, x1, x2)
136135

137-
shape = broadcast_shapes(*shapes)
136+
shape = sh.broadcast_shapes(*shapes)
138137
ph.assert_shape("where", out.shape, shape)
139138
# TODO: generate indices without broadcasting arrays
140139
_cond = xp.broadcast_to(cond, shape)

array_api_tests/test_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# TODO: disable if opted out, refactor things
22
import math
3-
import pytest
43
from collections import Counter, defaultdict
54

5+
import pytest
66
from hypothesis import assume, given
77

88
from . import _array_module as xp

array_api_tests/test_sorting_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
2-
import pytest
32
from typing import Set
43

4+
import pytest
55
from hypothesis import given
66
from hypothesis import strategies as st
77
from hypothesis.control import assume

array_api_tests/test_statistical_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
2-
import pytest
32
from typing import Optional
43

4+
import pytest
55
from hypothesis import assume, given
66
from hypothesis import strategies as st
77
from hypothesis.control import reject

0 commit comments

Comments
 (0)