|
5 | 5 | if TYPE_CHECKING:
|
6 | 6 | from ._typing import Array, ModuleType
|
7 | 7 |
|
8 |
| -__all__ = ["atleast_nd"] |
| 8 | +__all__ = ["atleast_nd", "expand_dims", "kron"] |
9 | 9 |
|
10 | 10 |
|
11 |
| -def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array: |
| 11 | +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: |
12 | 12 | """
|
13 | 13 | Recursively expand the dimension of an array to at least `ndim`.
|
14 | 14 |
|
@@ -46,3 +46,193 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
|
46 | 46 | x = xp.expand_dims(x, axis=0)
|
47 | 47 | x = atleast_nd(x, ndim=ndim, xp=xp)
|
48 | 48 | return x
|
| 49 | + |
| 50 | + |
| 51 | +def expand_dims( |
| 52 | + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType |
| 53 | +) -> Array: |
| 54 | + """ |
| 55 | + Expand the shape of an array. |
| 56 | +
|
| 57 | + Insert (a) new axis/axes that will appear at the position(s) specified by |
| 58 | + `axis` in the expanded array shape. |
| 59 | +
|
| 60 | + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. |
| 61 | + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. |
| 62 | +
|
| 63 | + Parameters |
| 64 | + ---------- |
| 65 | + a : array |
| 66 | + axis : int or tuple of ints, optional |
| 67 | + Position(s) in the expanded axes where the new axis (or axes) is/are placed. |
| 68 | + If multiple positions are provided, they should be unique (note that a position |
| 69 | + given by a positive index could also be referred to by a negative index - |
| 70 | + that will also result in an error). |
| 71 | + Default: ``(0,)``. |
| 72 | + xp : array_namespace |
| 73 | + The standard-compatible namespace for `a`. |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + res : array |
| 78 | + `a` with an expanded shape. |
| 79 | +
|
| 80 | + Examples |
| 81 | + -------- |
| 82 | + >>> import array_api_strict as xp |
| 83 | + >>> import array_api_extra as xpx |
| 84 | + >>> x = xp.asarray([1, 2]) |
| 85 | + >>> x.shape |
| 86 | + (2,) |
| 87 | +
|
| 88 | + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: |
| 89 | +
|
| 90 | + >>> y = xpx.expand_dims(x, axis=0, xp=xp) |
| 91 | + >>> y |
| 92 | + Array([[1, 2]], dtype=array_api_strict.int64) |
| 93 | + >>> y.shape |
| 94 | + (1, 2) |
| 95 | +
|
| 96 | + The following is equivalent to ``x[:, xp.newaxis]``: |
| 97 | +
|
| 98 | + >>> y = xpx.expand_dims(x, axis=1, xp=xp) |
| 99 | + >>> y |
| 100 | + Array([[1], |
| 101 | + [2]], dtype=array_api_strict.int64) |
| 102 | + >>> y.shape |
| 103 | + (2, 1) |
| 104 | +
|
| 105 | + ``axis`` may also be a tuple: |
| 106 | +
|
| 107 | + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) |
| 108 | + >>> y |
| 109 | + Array([[[1, 2]]], dtype=array_api_strict.int64) |
| 110 | +
|
| 111 | + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) |
| 112 | + >>> y |
| 113 | + Array([[[1], |
| 114 | + [2]]], dtype=array_api_strict.int64) |
| 115 | +
|
| 116 | + """ |
| 117 | + if not isinstance(axis, tuple): |
| 118 | + axis = (axis,) |
| 119 | + ndim = a.ndim + len(axis) |
| 120 | + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): |
| 121 | + err_msg = ( |
| 122 | + f"a provided axis position is out of bounds for array of dimension {a.ndim}" |
| 123 | + ) |
| 124 | + raise IndexError(err_msg) |
| 125 | + axis = tuple(dim % ndim for dim in axis) |
| 126 | + if len(set(axis)) != len(axis): |
| 127 | + err_msg = "Duplicate dimensions specified in `axis`." |
| 128 | + raise ValueError(err_msg) |
| 129 | + for i in sorted(axis): |
| 130 | + a = xp.expand_dims(a, axis=i) |
| 131 | + return a |
| 132 | + |
| 133 | + |
| 134 | +def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: |
| 135 | + """ |
| 136 | + Kronecker product of two arrays. |
| 137 | +
|
| 138 | + Computes the Kronecker product, a composite array made of blocks of the |
| 139 | + second array scaled by the first. |
| 140 | +
|
| 141 | + Equivalent to ``numpy.kron`` for NumPy arrays. |
| 142 | +
|
| 143 | + Parameters |
| 144 | + ---------- |
| 145 | + a, b : array |
| 146 | + xp : array_namespace |
| 147 | + The standard-compatible namespace for `a` and `b`. |
| 148 | +
|
| 149 | + Returns |
| 150 | + ------- |
| 151 | + res : array |
| 152 | + The Kronecker product of `a` and `b`. |
| 153 | +
|
| 154 | + Notes |
| 155 | + ----- |
| 156 | + The function assumes that the number of dimensions of `a` and `b` |
| 157 | + are the same, if necessary prepending the smallest with ones. |
| 158 | + If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``, |
| 159 | + the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``. |
| 160 | + The elements are products of elements from `a` and `b`, organized |
| 161 | + explicitly by:: |
| 162 | +
|
| 163 | + kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN] |
| 164 | +
|
| 165 | + where:: |
| 166 | +
|
| 167 | + kt = it * st + jt, t = 0,...,N |
| 168 | +
|
| 169 | + In the common 2-D case (N=1), the block structure can be visualized:: |
| 170 | +
|
| 171 | + [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], |
| 172 | + [ ... ... ], |
| 173 | + [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] |
| 174 | +
|
| 175 | +
|
| 176 | + Examples |
| 177 | + -------- |
| 178 | + >>> import array_api_strict as xp |
| 179 | + >>> import array_api_extra as xpx |
| 180 | + >>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp) |
| 181 | + Array([ 5, 6, 7, 50, 60, 70, 500, |
| 182 | + 600, 700], dtype=array_api_strict.int64) |
| 183 | +
|
| 184 | + >>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp) |
| 185 | + Array([ 5, 50, 500, 6, 60, 600, 7, |
| 186 | + 70, 700], dtype=array_api_strict.int64) |
| 187 | +
|
| 188 | + >>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp) |
| 189 | + Array([[1., 1., 0., 0.], |
| 190 | + [1., 1., 0., 0.], |
| 191 | + [0., 0., 1., 1.], |
| 192 | + [0., 0., 1., 1.]], dtype=array_api_strict.float64) |
| 193 | +
|
| 194 | +
|
| 195 | + >>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5)) |
| 196 | + >>> b = xp.reshape(xp.arange(24), (2, 3, 4)) |
| 197 | + >>> c = xpx.kron(a, b, xp=xp) |
| 198 | + >>> c.shape |
| 199 | + (2, 10, 6, 20) |
| 200 | + >>> I = (1, 3, 0, 2) |
| 201 | + >>> J = (0, 2, 1) |
| 202 | + >>> J1 = (0,) + J # extend to ndim=4 |
| 203 | + >>> S1 = (1,) + b.shape |
| 204 | + >>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1)) |
| 205 | + >>> c[K] == a[I]*b[J] |
| 206 | + Array(True, dtype=array_api_strict.bool) |
| 207 | +
|
| 208 | + """ |
| 209 | + |
| 210 | + b = xp.asarray(b) |
| 211 | + singletons = (1,) * (b.ndim - a.ndim) |
| 212 | + a = xp.broadcast_to(xp.asarray(a), singletons + a.shape) |
| 213 | + |
| 214 | + nd_b, nd_a = b.ndim, a.ndim |
| 215 | + nd_max = max(nd_b, nd_a) |
| 216 | + if nd_a == 0 or nd_b == 0: |
| 217 | + return xp.multiply(a, b) |
| 218 | + |
| 219 | + a_shape = a.shape |
| 220 | + b_shape = b.shape |
| 221 | + |
| 222 | + # Equalise the shapes by prepending smaller one with 1s |
| 223 | + a_shape = (1,) * max(0, nd_b - nd_a) + a_shape |
| 224 | + b_shape = (1,) * max(0, nd_a - nd_b) + b_shape |
| 225 | + |
| 226 | + # Insert empty dimensions |
| 227 | + a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp) |
| 228 | + b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp) |
| 229 | + |
| 230 | + # Compute the product |
| 231 | + a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp) |
| 232 | + b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp) |
| 233 | + result = xp.multiply(a_arr, b_arr) |
| 234 | + |
| 235 | + # Reshape back and return |
| 236 | + a_shape = xp.asarray(a_shape) |
| 237 | + b_shape = xp.asarray(b_shape) |
| 238 | + return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape))) |
0 commit comments