2
2
from itertools import product
3
3
from typing import Iterator , List , Optional , Tuple , Union
4
4
5
- from . typing import Scalar , Shape
5
+ from ndindex import iter_indices as _iter_indices
6
6
7
- __all__ = ["normalise_axis" , "ndindex" , "axis_ndindex" , "axes_ndindex" , "reshape" ]
7
+ from .typing import AtomicIndex , Index , Scalar , Shape
8
+
9
+ __all__ = [
10
+ "broadcast_shapes" ,
11
+ "normalise_axis" ,
12
+ "ndindex" ,
13
+ "axis_ndindex" ,
14
+ "axes_ndindex" ,
15
+ "reshape" ,
16
+ "fmt_idx" ,
17
+ ]
18
+
19
+
20
+ class BroadcastError (ValueError ):
21
+ """Shapes do not broadcast with eachother"""
22
+
23
+
24
+ def _broadcast_shapes (shape1 : Shape , shape2 : Shape ) -> Shape :
25
+ """Broadcasts `shape1` and `shape2`"""
26
+ N1 = len (shape1 )
27
+ N2 = len (shape2 )
28
+ N = max (N1 , N2 )
29
+ shape = [None for _ in range (N )]
30
+ i = N - 1
31
+ while i >= 0 :
32
+ n1 = N1 - N + i
33
+ if N1 - N + i >= 0 :
34
+ d1 = shape1 [n1 ]
35
+ else :
36
+ d1 = 1
37
+ n2 = N2 - N + i
38
+ if N2 - N + i >= 0 :
39
+ d2 = shape2 [n2 ]
40
+ else :
41
+ d2 = 1
42
+
43
+ if d1 == 1 :
44
+ shape [i ] = d2
45
+ elif d2 == 1 :
46
+ shape [i ] = d1
47
+ elif d1 == d2 :
48
+ shape [i ] = d1
49
+ else :
50
+ raise BroadcastError ()
51
+
52
+ i = i - 1
53
+
54
+ return tuple (shape )
55
+
56
+
57
+ def broadcast_shapes (* shapes : Shape ):
58
+ if len (shapes ) == 0 :
59
+ raise ValueError ("shapes=[] must be non-empty" )
60
+ elif len (shapes ) == 1 :
61
+ return shapes [0 ]
62
+ result = _broadcast_shapes (shapes [0 ], shapes [1 ])
63
+ for i in range (2 , len (shapes )):
64
+ result = _broadcast_shapes (result , shapes [i ])
65
+ return result
8
66
9
67
10
68
def normalise_axis (
@@ -17,13 +75,21 @@ def normalise_axis(
17
75
return axes
18
76
19
77
20
- def ndindex (shape ):
21
- """Iterator of n-D indices to an array
78
+ def ndindex (shape : Shape ) -> Iterator [Index ]:
79
+ """Yield every index of a shape"""
80
+ return (indices [0 ] for indices in iter_indices (shape ))
81
+
22
82
23
- Yields tuples of integers to index every element of an array of shape
24
- `shape`. Same as np.ndindex().
25
- """
26
- return product (* [range (i ) for i in shape ])
83
+ def iter_indices (
84
+ * shapes : Shape , skip_axes : Tuple [int , ...] = ()
85
+ ) -> Iterator [Tuple [Index , ...]]:
86
+ """Wrapper for ndindex.iter_indices()"""
87
+ # Prevent iterations if any shape has 0-sides
88
+ for shape in shapes :
89
+ if 0 in shape :
90
+ return
91
+ for indices in _iter_indices (* shapes , skip_axes = skip_axes ):
92
+ yield tuple (i .raw for i in indices ) # type: ignore
27
93
28
94
29
95
def axis_ndindex (
@@ -60,7 +126,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
60
126
yield list (indices )
61
127
62
128
63
- def reshape (flat_seq : List [Scalar ], shape : Shape ) -> Union [Scalar , List [ Scalar ] ]:
129
+ def reshape (flat_seq : List [Scalar ], shape : Shape ) -> Union [Scalar , List ]:
64
130
"""Reshape a flat sequence"""
65
131
if any (s == 0 for s in shape ):
66
132
raise ValueError (
@@ -75,3 +141,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]
75
141
size = len (flat_seq )
76
142
n = math .prod (shape [1 :])
77
143
return [reshape (flat_seq [i * n : (i + 1 ) * n ], shape [1 :]) for i in range (size // n )]
144
+
145
+
146
+ def fmt_i (i : AtomicIndex ) -> str :
147
+ if isinstance (i , int ):
148
+ return str (i )
149
+ elif isinstance (i , slice ):
150
+ res = ""
151
+ if i .start is not None :
152
+ res += str (i .start )
153
+ res += ":"
154
+ if i .stop is not None :
155
+ res += str (i .stop )
156
+ if i .step is not None :
157
+ res += f":{ i .step } "
158
+ return res
159
+ else :
160
+ return "..."
161
+
162
+
163
+ def fmt_idx (sym : str , idx : Index ) -> str :
164
+ if idx == ():
165
+ return sym
166
+ res = f"{ sym } ["
167
+ _idx = idx if isinstance (idx , tuple ) else (idx ,)
168
+ if len (_idx ) == 1 :
169
+ res += fmt_i (_idx [0 ])
170
+ else :
171
+ res += ", " .join (fmt_i (i ) for i in _idx )
172
+ res += "]"
173
+ return res
0 commit comments