@@ -58,14 +58,14 @@ def expand_dims(
58
58
`axis` in the expanded array shape.
59
59
60
60
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
61
- Equivalent to ``numpy.expand_dims`` for NumPy arrays.
61
+ Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
62
62
63
63
Parameters
64
64
----------
65
65
a : array
66
66
axis : int or tuple of ints
67
67
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68
- If multiple positions are provided, they should be unique.
68
+ If multiple positions are provided, they should be unique and increasing .
69
69
Default: ``(0,)``.
70
70
xp : array_namespace
71
71
The standard-compatible namespace for `a`.
@@ -77,52 +77,54 @@ def expand_dims(
77
77
78
78
Examples
79
79
--------
80
- # >>> import numpy as np
81
- # >>> x = np.array([1, 2])
82
- # >>> x.shape
83
- # (2,)
84
-
85
- # The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
86
-
87
- # >>> y = np.expand_dims(x, axis=0)
88
- # >>> y
89
- # array([[1, 2]])
90
- # >>> y.shape
91
- # (1, 2)
80
+ >>> import array_api_strict as xp
81
+ >>> import array_api_extra as xpx
82
+ >>> x = xp.asarray([1, 2])
83
+ >>> x.shape
84
+ (2,)
92
85
93
- # The following is equivalent to ``x[:, np .newaxis]``:
86
+ The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp .newaxis]``:
94
87
95
- # >>> y = np.expand_dims(x, axis=1)
96
- # >>> y
97
- # array([[1],
98
- # [2]])
99
- # >>> y.shape
100
- # (2, 1)
88
+ >>> y = xpx.expand_dims(x, axis=0, xp=xp)
89
+ >>> y
90
+ Array([[1, 2]], dtype=array_api_strict.int64)
91
+ >>> y.shape
92
+ (1, 2)
101
93
102
- # ``axis`` may also be a tuple :
94
+ The following is equivalent to ``x[:, xp.newaxis]`` :
103
95
104
- # >>> y = np.expand_dims(x, axis=(0, 1))
105
- # >>> y
106
- # array([[[1, 2]]])
96
+ >>> y = xpx.expand_dims(x, axis=1, xp=xp)
97
+ >>> y
98
+ Array([[1],
99
+ [2]], dtype=array_api_strict.int64)
100
+ >>> y.shape
101
+ (2, 1)
107
102
108
- # >>> y = np.expand_dims(x, axis=(2, 0))
109
- # >>> y
110
- # array([[[1],
111
- # [2]]])
103
+ ``axis`` may also be a tuple:
112
104
113
- # Note that some examples may use ``None`` instead of ``np.newaxis``. These
114
- # are the same objects:
105
+ >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
106
+ >>> y
107
+ Array([[[1, 2]]], dtype=array_api_strict.int64)
115
108
116
- # >>> np.newaxis is None
117
- # True
109
+ >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
110
+ >>> y
111
+ Array([[[1],
112
+ [2]]], dtype=array_api_strict.int64)
118
113
119
114
"""
120
115
if not isinstance (axis , tuple ):
121
116
axis = (axis ,)
122
117
if len (set (axis )) != len (axis ):
123
118
err_msg = "Duplicate dimensions specified in `axis`."
124
119
raise ValueError (err_msg )
125
- for i in axis :
120
+ ndim = a .ndim + len (axis )
121
+ if axis != () and (min (axis ) < - ndim or max (axis ) >= ndim ):
122
+ err_msg = (
123
+ f"a provided axis position is out of bounds for array of dimension { a .ndim } "
124
+ )
125
+ raise IndexError (err_msg )
126
+ axis = tuple (dim % ndim for dim in axis )
127
+ for i in sorted (axis ):
126
128
a = xp .expand_dims (a , axis = i )
127
129
return a
128
130
@@ -145,6 +147,7 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
145
147
Returns
146
148
-------
147
149
res : array
150
+ The Kronecker product of `a` and `b`.
148
151
149
152
Notes
150
153
-----
@@ -170,30 +173,35 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
170
173
171
174
Examples
172
175
--------
173
- # >>> import numpy as np
174
- # >>> np.kron([1,10,100], [5,6,7])
175
- # array([ 5, 6, 7, ..., 500, 600, 700])
176
- # >>> np.kron([5,6,7], [1,10,100])
177
- # array([ 5, 50, 500, ..., 7, 70, 700])
178
-
179
- # >>> np.kron(np.eye(2), np.ones((2,2)))
180
- # array([[1., 1., 0., 0.],
181
- # [1., 1., 0., 0.],
182
- # [0., 0., 1., 1.],
183
- # [0., 0., 1., 1.]])
184
-
185
- # >>> a = np.arange(100).reshape((2,5,2,5))
186
- # >>> b = np.arange(24).reshape((2,3,4))
187
- # >>> c = np.kron(a,b)
188
- # >>> c.shape
189
- # (2, 10, 6, 20)
190
- # >>> I = (1,3,0,2)
191
- # >>> J = (0,2,1)
192
- # >>> J1 = (0,) + J # extend to ndim=4
193
- # >>> S1 = (1,) + b.shape
194
- # >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
195
- # >>> c[K] == a[I]*b[J]
196
- # True
176
+ >>> import array_api_strict as xp
177
+ >>> import array_api_extra as xpx
178
+ >>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
179
+ Array([ 5, 6, 7, 50, 60, 70, 500,
180
+ 600, 700], dtype=array_api_strict.int64)
181
+
182
+ >>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
183
+ Array([ 5, 50, 500, 6, 60, 600, 7,
184
+ 70, 700], dtype=array_api_strict.int64)
185
+
186
+ >>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
187
+ Array([[1., 1., 0., 0.],
188
+ [1., 1., 0., 0.],
189
+ [0., 0., 1., 1.],
190
+ [0., 0., 1., 1.]], dtype=array_api_strict.float64)
191
+
192
+
193
+ >>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
194
+ >>> b = xp.reshape(xp.arange(24), (2, 3, 4))
195
+ >>> c = xpx.kron(a, b, xp=xp)
196
+ >>> c.shape
197
+ (2, 10, 6, 20)
198
+ >>> I = (1, 3, 0, 2)
199
+ >>> J = (0, 2, 1)
200
+ >>> J1 = (0,) + J # extend to ndim=4
201
+ >>> S1 = (1,) + b.shape
202
+ >>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
203
+ >>> c[K] == a[I]*b[J]
204
+ Array(True, dtype=array_api_strict.bool)
197
205
198
206
"""
199
207
0 commit comments