Skip to content

Commit 3388ab4

Browse files
committed
DOC: add examples for kron and expand_dims
1 parent c380728 commit 3388ab4

File tree

2 files changed

+68
-58
lines changed

2 files changed

+68
-58
lines changed

docs/api-reference.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@
77
:toctree: generated
88
99
atleast_nd
10+
expand_dims
11+
kron
1012
```

src/array_api_extra/_funcs.py

+66-58
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def expand_dims(
5858
`axis` in the expanded array shape.
5959
6060
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
61-
Equivalent to ``numpy.expand_dims`` for NumPy arrays.
61+
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
6262
6363
Parameters
6464
----------
6565
a : array
6666
axis : int or tuple of ints
6767
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68-
If multiple positions are provided, they should be unique.
68+
If multiple positions are provided, they should be unique and increasing.
6969
Default: ``(0,)``.
7070
xp : array_namespace
7171
The standard-compatible namespace for `a`.
@@ -77,52 +77,54 @@ def expand_dims(
7777
7878
Examples
7979
--------
80-
# >>> import numpy as np
81-
# >>> x = np.array([1, 2])
82-
# >>> x.shape
83-
# (2,)
84-
85-
# The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
86-
87-
# >>> y = np.expand_dims(x, axis=0)
88-
# >>> y
89-
# array([[1, 2]])
90-
# >>> y.shape
91-
# (1, 2)
80+
>>> import array_api_strict as xp
81+
>>> import array_api_extra as xpx
82+
>>> x = xp.asarray([1, 2])
83+
>>> x.shape
84+
(2,)
9285
93-
# The following is equivalent to ``x[:, np.newaxis]``:
86+
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
9487
95-
# >>> y = np.expand_dims(x, axis=1)
96-
# >>> y
97-
# array([[1],
98-
# [2]])
99-
# >>> y.shape
100-
# (2, 1)
88+
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
89+
>>> y
90+
Array([[1, 2]], dtype=array_api_strict.int64)
91+
>>> y.shape
92+
(1, 2)
10193
102-
# ``axis`` may also be a tuple:
94+
The following is equivalent to ``x[:, xp.newaxis]``:
10395
104-
# >>> y = np.expand_dims(x, axis=(0, 1))
105-
# >>> y
106-
# array([[[1, 2]]])
96+
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
97+
>>> y
98+
Array([[1],
99+
[2]], dtype=array_api_strict.int64)
100+
>>> y.shape
101+
(2, 1)
107102
108-
# >>> y = np.expand_dims(x, axis=(2, 0))
109-
# >>> y
110-
# array([[[1],
111-
# [2]]])
103+
``axis`` may also be a tuple:
112104
113-
# Note that some examples may use ``None`` instead of ``np.newaxis``. These
114-
# are the same objects:
105+
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
106+
>>> y
107+
Array([[[1, 2]]], dtype=array_api_strict.int64)
115108
116-
# >>> np.newaxis is None
117-
# True
109+
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
110+
>>> y
111+
Array([[[1],
112+
[2]]], dtype=array_api_strict.int64)
118113
119114
"""
120115
if not isinstance(axis, tuple):
121116
axis = (axis,)
122117
if len(set(axis)) != len(axis):
123118
err_msg = "Duplicate dimensions specified in `axis`."
124119
raise ValueError(err_msg)
125-
for i in axis:
120+
ndim = a.ndim + len(axis)
121+
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
122+
err_msg = (
123+
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
124+
)
125+
raise IndexError(err_msg)
126+
axis = tuple(dim % ndim for dim in axis)
127+
for i in sorted(axis):
126128
a = xp.expand_dims(a, axis=i)
127129
return a
128130

@@ -145,6 +147,7 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
145147
Returns
146148
-------
147149
res : array
150+
The Kronecker product of `a` and `b`.
148151
149152
Notes
150153
-----
@@ -170,30 +173,35 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
170173
171174
Examples
172175
--------
173-
# >>> import numpy as np
174-
# >>> np.kron([1,10,100], [5,6,7])
175-
# array([ 5, 6, 7, ..., 500, 600, 700])
176-
# >>> np.kron([5,6,7], [1,10,100])
177-
# array([ 5, 50, 500, ..., 7, 70, 700])
178-
179-
# >>> np.kron(np.eye(2), np.ones((2,2)))
180-
# array([[1., 1., 0., 0.],
181-
# [1., 1., 0., 0.],
182-
# [0., 0., 1., 1.],
183-
# [0., 0., 1., 1.]])
184-
185-
# >>> a = np.arange(100).reshape((2,5,2,5))
186-
# >>> b = np.arange(24).reshape((2,3,4))
187-
# >>> c = np.kron(a,b)
188-
# >>> c.shape
189-
# (2, 10, 6, 20)
190-
# >>> I = (1,3,0,2)
191-
# >>> J = (0,2,1)
192-
# >>> J1 = (0,) + J # extend to ndim=4
193-
# >>> S1 = (1,) + b.shape
194-
# >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
195-
# >>> c[K] == a[I]*b[J]
196-
# True
176+
>>> import array_api_strict as xp
177+
>>> import array_api_extra as xpx
178+
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
179+
Array([ 5, 6, 7, 50, 60, 70, 500,
180+
600, 700], dtype=array_api_strict.int64)
181+
182+
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
183+
Array([ 5, 50, 500, 6, 60, 600, 7,
184+
70, 700], dtype=array_api_strict.int64)
185+
186+
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
187+
Array([[1., 1., 0., 0.],
188+
[1., 1., 0., 0.],
189+
[0., 0., 1., 1.],
190+
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
191+
192+
193+
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
194+
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
195+
>>> c = xpx.kron(a, b, xp=xp)
196+
>>> c.shape
197+
(2, 10, 6, 20)
198+
>>> I = (1, 3, 0, 2)
199+
>>> J = (0, 2, 1)
200+
>>> J1 = (0,) + J # extend to ndim=4
201+
>>> S1 = (1,) + b.shape
202+
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
203+
>>> c[K] == a[I]*b[J]
204+
Array(True, dtype=array_api_strict.bool)
197205
198206
"""
199207

0 commit comments

Comments
 (0)