Skip to content

Commit f524872

Browse files
committed
sh.fmt_idx() helper
1 parent bee123a commit f524872

File tree

3 files changed

+64
-4
lines changed

3 files changed

+64
-4
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,22 @@ def test_axes_ndindex(shape, axes, expected):
8282
)
8383
def test_roll_ndindex(shape, shifts, axes, expected):
8484
assert list(roll_ndindex(shape, shifts, axes)) == expected
85+
86+
87+
@pytest.mark.parametrize(
88+
"idx, expected",
89+
[
90+
((), "x"),
91+
(42, "x[42]"),
92+
((42,), "x[42]"),
93+
(slice(None, 2), "x[:2]"),
94+
(slice(2, None), "x[2:]"),
95+
(slice(0, 2), "x[0:2]"),
96+
(slice(0, 2, -1), "x[0:2:-1]"),
97+
(slice(None, None, -1), "x[::-1]"),
98+
(slice(None, None), "x[:]"),
99+
(..., "x[...]"),
100+
],
101+
)
102+
def test_fmt_idx(idx, expected):
103+
assert sh.fmt_idx("x", idx) == expected

array_api_tests/shape_helpers.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,16 @@
44

55
from ndindex import iter_indices as _iter_indices
66

7-
from .typing import Scalar, Shape
7+
from .typing import AtomicIndex, Index, Scalar, Shape
88

9-
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
9+
__all__ = [
10+
"normalise_axis",
11+
"ndindex",
12+
"axis_ndindex",
13+
"axes_ndindex",
14+
"reshape",
15+
"fmt_idx",
16+
]
1017

1118

1219
def normalise_axis(
@@ -64,7 +71,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
6471
yield list(indices)
6572

6673

67-
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
74+
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
6875
"""Reshape a flat sequence"""
6976
if any(s == 0 for s in shape):
7077
raise ValueError(
@@ -79,3 +86,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]
7986
size = len(flat_seq)
8087
n = math.prod(shape[1:])
8188
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
89+
90+
91+
def fmt_i(i: AtomicIndex) -> str:
92+
if isinstance(i, int):
93+
return str(i)
94+
elif isinstance(i, slice):
95+
res = ""
96+
if i.start is not None:
97+
res += str(i.start)
98+
res += ":"
99+
if i.stop is not None:
100+
res += str(i.stop)
101+
if i.step is not None:
102+
res += f":{i.step}"
103+
return res
104+
else:
105+
return "..."
106+
107+
108+
def fmt_idx(sym: str, idx: Index) -> str:
109+
if idx == ():
110+
return sym
111+
res = f"{sym}["
112+
_idx = idx if isinstance(idx, tuple) else (idx,)
113+
if len(_idx) == 1:
114+
res += fmt_i(_idx[0])
115+
else:
116+
res += ", ".join(fmt_i(i) for i in _idx)
117+
res += "]"
118+
return res

array_api_tests/typing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Tuple, Type, Union, Any
1+
from typing import Any, Tuple, Type, Union
22

33
__all__ = [
44
"DataType",
55
"Scalar",
66
"ScalarType",
77
"Array",
88
"Shape",
9+
"AtomicIndex",
10+
"Index",
911
"Param",
1012
]
1113

@@ -14,4 +16,6 @@
1416
ScalarType = Union[Type[bool], Type[int], Type[float]]
1517
Array = Any
1618
Shape = Tuple[int, ...]
19+
AtomicIndex = Union[int, "ellipsis", slice] # noqa
20+
Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]]
1721
Param = Tuple

0 commit comments

Comments
 (0)