Skip to content

Commit be06b63

Browse files
authored
Merge pull request #7 from lucascolley/kron
2 parents 13dc1ef + fa397d2 commit be06b63

File tree

6 files changed

+349
-21
lines changed

6 files changed

+349
-21
lines changed

docs/api-reference.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@
77
:toctree: generated
88
99
atleast_nd
10+
expand_dims
11+
kron
1012
```

pixi.lock

+29-15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pylint = "*"
7272
# import dependencies for mypy:
7373
array-api-strict = "*"
7474
numpy = "*"
75+
pytest = "*"
7576

7677
[tool.pixi.feature.lint.tasks]
7778
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }

src/array_api_extra/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd
3+
from ._funcs import atleast_nd, expand_dims, kron
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd"]
7+
__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"]

src/array_api_extra/_funcs.py

+192-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
if TYPE_CHECKING:
66
from ._typing import Array, ModuleType
77

8-
__all__ = ["atleast_nd"]
8+
__all__ = ["atleast_nd", "expand_dims", "kron"]
99

1010

11-
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
11+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
1212
"""
1313
Recursively expand the dimension of an array to at least `ndim`.
1414
@@ -46,3 +46,193 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4646
x = xp.expand_dims(x, axis=0)
4747
x = atleast_nd(x, ndim=ndim, xp=xp)
4848
return x
49+
50+
51+
def expand_dims(
52+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
53+
) -> Array:
54+
"""
55+
Expand the shape of an array.
56+
57+
Insert (a) new axis/axes that will appear at the position(s) specified by
58+
`axis` in the expanded array shape.
59+
60+
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
61+
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
62+
63+
Parameters
64+
----------
65+
a : array
66+
axis : int or tuple of ints, optional
67+
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68+
If multiple positions are provided, they should be unique (note that a position
69+
given by a positive index could also be referred to by a negative index -
70+
that will also result in an error).
71+
Default: ``(0,)``.
72+
xp : array_namespace
73+
The standard-compatible namespace for `a`.
74+
75+
Returns
76+
-------
77+
res : array
78+
`a` with an expanded shape.
79+
80+
Examples
81+
--------
82+
>>> import array_api_strict as xp
83+
>>> import array_api_extra as xpx
84+
>>> x = xp.asarray([1, 2])
85+
>>> x.shape
86+
(2,)
87+
88+
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
89+
90+
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
91+
>>> y
92+
Array([[1, 2]], dtype=array_api_strict.int64)
93+
>>> y.shape
94+
(1, 2)
95+
96+
The following is equivalent to ``x[:, xp.newaxis]``:
97+
98+
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
99+
>>> y
100+
Array([[1],
101+
[2]], dtype=array_api_strict.int64)
102+
>>> y.shape
103+
(2, 1)
104+
105+
``axis`` may also be a tuple:
106+
107+
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
108+
>>> y
109+
Array([[[1, 2]]], dtype=array_api_strict.int64)
110+
111+
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
112+
>>> y
113+
Array([[[1],
114+
[2]]], dtype=array_api_strict.int64)
115+
116+
"""
117+
if not isinstance(axis, tuple):
118+
axis = (axis,)
119+
ndim = a.ndim + len(axis)
120+
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
121+
err_msg = (
122+
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
123+
)
124+
raise IndexError(err_msg)
125+
axis = tuple(dim % ndim for dim in axis)
126+
if len(set(axis)) != len(axis):
127+
err_msg = "Duplicate dimensions specified in `axis`."
128+
raise ValueError(err_msg)
129+
for i in sorted(axis):
130+
a = xp.expand_dims(a, axis=i)
131+
return a
132+
133+
134+
def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
135+
"""
136+
Kronecker product of two arrays.
137+
138+
Computes the Kronecker product, a composite array made of blocks of the
139+
second array scaled by the first.
140+
141+
Equivalent to ``numpy.kron`` for NumPy arrays.
142+
143+
Parameters
144+
----------
145+
a, b : array
146+
xp : array_namespace
147+
The standard-compatible namespace for `a` and `b`.
148+
149+
Returns
150+
-------
151+
res : array
152+
The Kronecker product of `a` and `b`.
153+
154+
Notes
155+
-----
156+
The function assumes that the number of dimensions of `a` and `b`
157+
are the same, if necessary prepending the smallest with ones.
158+
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
159+
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
160+
The elements are products of elements from `a` and `b`, organized
161+
explicitly by::
162+
163+
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
164+
165+
where::
166+
167+
kt = it * st + jt, t = 0,...,N
168+
169+
In the common 2-D case (N=1), the block structure can be visualized::
170+
171+
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
172+
[ ... ... ],
173+
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
174+
175+
176+
Examples
177+
--------
178+
>>> import array_api_strict as xp
179+
>>> import array_api_extra as xpx
180+
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
181+
Array([ 5, 6, 7, 50, 60, 70, 500,
182+
600, 700], dtype=array_api_strict.int64)
183+
184+
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
185+
Array([ 5, 50, 500, 6, 60, 600, 7,
186+
70, 700], dtype=array_api_strict.int64)
187+
188+
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
189+
Array([[1., 1., 0., 0.],
190+
[1., 1., 0., 0.],
191+
[0., 0., 1., 1.],
192+
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
193+
194+
195+
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
196+
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
197+
>>> c = xpx.kron(a, b, xp=xp)
198+
>>> c.shape
199+
(2, 10, 6, 20)
200+
>>> I = (1, 3, 0, 2)
201+
>>> J = (0, 2, 1)
202+
>>> J1 = (0,) + J # extend to ndim=4
203+
>>> S1 = (1,) + b.shape
204+
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
205+
>>> c[K] == a[I]*b[J]
206+
Array(True, dtype=array_api_strict.bool)
207+
208+
"""
209+
210+
b = xp.asarray(b)
211+
singletons = (1,) * (b.ndim - a.ndim)
212+
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
213+
214+
nd_b, nd_a = b.ndim, a.ndim
215+
nd_max = max(nd_b, nd_a)
216+
if nd_a == 0 or nd_b == 0:
217+
return xp.multiply(a, b)
218+
219+
a_shape = a.shape
220+
b_shape = b.shape
221+
222+
# Equalise the shapes by prepending smaller one with 1s
223+
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
224+
b_shape = (1,) * max(0, nd_a - nd_b) + b_shape
225+
226+
# Insert empty dimensions
227+
a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp)
228+
b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp)
229+
230+
# Compute the product
231+
a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp)
232+
b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp)
233+
result = xp.multiply(a_arr, b_arr)
234+
235+
# Reshape back and return
236+
a_shape = xp.asarray(a_shape)
237+
b_shape = xp.asarray(b_shape)
238+
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))

0 commit comments

Comments
 (0)