Skip to content

Commit baf84f2

Browse files
lucascolleytupui
andcommitted
ENH: add atleast_nd
Co-authored-by: Pamphile Roy <[email protected]>
1 parent f2a28cb commit baf84f2

File tree

6 files changed

+848
-3
lines changed

6 files changed

+848
-3
lines changed

pixi.lock

+783-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,20 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = []
29+
dependencies = [
30+
"array-api-compat",
31+
]
3032

3133
[project.optional-dependencies]
3234
test = [
3335
"pytest >=6",
3436
"pytest-cov >=3",
37+
"array-api-strict",
3538
]
3639
dev = [
3740
"pytest >=6",
3841
"pytest-cov >=3",
42+
"array-api-strict",
3943
"pylint",
4044
]
4145
docs = [
@@ -81,6 +85,7 @@ lint = { depends-on = ["pre-commit", "pylint"] }
8185
[tool.pixi.feature.test.dependencies]
8286
pytest = ">=6"
8387
pytest-cov = ">=3"
88+
array-api-strict = "*"
8489

8590
[tool.pixi.feature.test.tasks]
8691
test = { cmd = "pytest" }

src/array_api_extra/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from ._funcs import atleast_nd
4+
35
__version__ = "0.1.dev0"
46

5-
__all__ = ["__version__"]
7+
__all__ = ["__version__", "atleast_nd"]

src/array_api_extra/_funcs.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from array_api_compat import array_namespace # type: ignore[import-not-found]
6+
7+
if TYPE_CHECKING:
8+
from ._typing import Array, ModuleType
9+
10+
__all__ = ["atleast_nd"]
11+
12+
13+
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
14+
"""
15+
Recursively expand the dimension of an array to have at least `ndim`.
16+
17+
Parameters
18+
----------
19+
x: array
20+
An array.
21+
22+
Returns
23+
-------
24+
res: array
25+
An array with ``res.ndim`` >= `ndim`.
26+
If ``x.ndim`` >= `ndim`, `x` is returned.
27+
If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`.
28+
"""
29+
xp = array_namespace(x) if xp is None else xp
30+
31+
x = xp.asarray(x)
32+
if x.ndim < ndim:
33+
x = xp.expand_dims(x, axis=0)
34+
x = atleast_nd(x, ndim=ndim, xp=xp)
35+
return x

src/array_api_extra/_typing.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from types import ModuleType
4+
from typing import TYPE_CHECKING, Any
5+
6+
if TYPE_CHECKING:
7+
Array = Any # To be changed to a Protocol later (see array-api#589)
8+
9+
__all__ = ["Array", "ModuleType"]

tests/test_funcs.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import array_api_strict as xp # type: ignore[import-not-found]
4+
5+
from array_api_extra import atleast_nd
6+
7+
8+
class TestAtLeastND:
9+
def test_1d_to_2d(self):
10+
x = xp.asarray([0, 1])
11+
y = atleast_nd(x, ndim=2, xp=xp)
12+
assert y.ndim == 2

0 commit comments

Comments
 (0)