5
5
6
6
import math
7
7
import warnings
8
- from collections .abc import Sequence
8
+ from collections .abc import Generator , Sequence
9
9
from types import ModuleType
10
10
from typing import cast
11
11
@@ -163,6 +163,16 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
163
163
return xp .squeeze (c , axis = axes )
164
164
165
165
166
+ def ndindex (* x : int ) -> Generator [tuple [int , ...]]:
167
+ if not x :
168
+ yield ()
169
+ return
170
+ indices = list (ndindex (* x [1 :]))
171
+ for i in range (x [0 ]):
172
+ for j in indices :
173
+ yield i , * j
174
+
175
+
166
176
def create_diagonal (
167
177
x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
168
178
) -> Array :
@@ -172,7 +182,7 @@ def create_diagonal(
172
182
Parameters
173
183
----------
174
184
x : array
175
- A 1-D array.
185
+ An array having shape (*broadcast_dims, k) .
176
186
offset : int, optional
177
187
Offset from the leading diagonal (default is ``0``).
178
188
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +193,8 @@ def create_diagonal(
183
193
Returns
184
194
-------
185
195
array
186
- A 2-D array with `x` on the diagonal (offset by `offset`).
196
+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
197
+ on the diagonal (offset by `offset`).
187
198
188
199
Examples
189
200
--------
@@ -206,18 +217,21 @@ def create_diagonal(
206
217
if xp is None :
207
218
xp = array_namespace (x )
208
219
209
- if x .ndim != 1 :
210
- err_msg = "`x` must be 1-dimensional."
220
+ if x .ndim == 0 :
221
+ err_msg = "`x` must be at least 1-dimensional."
211
222
raise ValueError (err_msg )
212
- n = x .shape [0 ] + abs (offset )
213
- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214
-
215
- start = offset if offset >= 0 else abs (offset ) * n
216
- stop = min (n * (n - offset ), diag .shape [0 ])
217
- step = n + 1
218
- diag = at (diag )[start :stop :step ].set (x )
219
-
220
- return xp .reshape (diag , (n , n ))
223
+ pre = x .shape [:- 1 ]
224
+ n = x .shape [- 1 ] + abs (offset )
225
+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
226
+
227
+ target_slice = slice (
228
+ offset if offset >= 0 else abs (offset ) * n ,
229
+ min (n * (n - offset ), diag .shape [- 1 ]),
230
+ n + 1 ,
231
+ )
232
+ for index in ndindex (* pre ):
233
+ diag = at (diag )[(* index , target_slice )].set (x [* index , :])
234
+ return xp .reshape (diag , (* pre , n , n ))
221
235
222
236
223
237
def expand_dims (
0 commit comments