Skip to content

Commit eaf1e52

Browse files
committed
Rename get_namespace to array_namespace
get_namespace is maintained as a backwards compatible alias. Fixes #19.
1 parent 902fefd commit eaf1e52

File tree

6 files changed

+65
-51
lines changed

6 files changed

+65
-51
lines changed

CHANGELOG.md

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# 1.1.1 (2023-03-08)
22

3+
## Major Changes
4+
5+
- Rename `get_namespace()` to `array_namespace()` (`get_namespace()` is
6+
maintained as a backwards compatible alias).
7+
38
## Minor Changes
49

510
- The minimum supported NumPy version is now 1.21. Fixed a few issues with
@@ -8,11 +13,14 @@
813

914
- Add `api_version` to `get_namespace()`.
1015

11-
- `get_namespace()` now works correctly with `torch` tensors.
16+
- `array_namespace()` (*née* `get_namespace()`) now works correctly with
17+
`torch` tensors.
1218

13-
- `get_namespace()` now works correctly with `numpy.array_api` arrays.
19+
- `array_namespace()` (*née* `get_namespace()`) now works correctly with
20+
`numpy.array_api` arrays.
1421

15-
- `get_namespace()` now raises `TypeError` instead of `ValueError`.
22+
- `array_namespace()` (*née* `get_namespace()`) now raises `TypeError` instead
23+
of `ValueError`.
1624

1725
- Fix the `torch.std` wrapper.
1826

README.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ later this year.
2121
## Usage
2222

2323
The typical usage of this library will be to get the corresponding array API
24-
compliant namespace from the input arrays using `get_namespace()`, like
24+
compliant namespace from the input arrays using `array_namespace()`, like
2525

2626
```py
2727
def your_function(x, y):
28-
xp = array_api_compat.get_namespace(x, y)
28+
xp = array_api_compat.array_namespace(x, y)
2929
# Now use xp as the array library namespace
3030
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
3131
```
@@ -88,7 +88,7 @@ part of the specification but which are useful for using the array API:
8888
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
8989
object.
9090

91-
- `get_namespace(*xs)`: Get the corresponding array API namespace for the
91+
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
9292
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
9393
namespace will be `array_api_compat.numpy`. Note that this function will
9494
also work for namespaces that aren't supported by this compat library but
@@ -133,7 +133,7 @@ specification:
133133
don't want to monkeypatch or wrap it. The helper functions `device()` and
134134
`to_device()` are provided to work around these missing methods (see above).
135135
`x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`.
136-
`get_namespace(x)` should be used instead of `x.__array_namespace__`.
136+
`array_namespace(x)` should be used instead of `x.__array_namespace__`.
137137

138138
- Value-based casting for scalars will be in effect unless explicitly disabled
139139
with the environment variable `NPY_PROMOTION_STATE=weak` or
@@ -168,7 +168,7 @@ version.
168168

169169
- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the
170170
`__array_namespace__` and `to_device` methods, so the corresponding helper
171-
functions `get_namespace()` and `to_device()` in this library should be
171+
functions `array_namespace()` and `to_device()` in this library should be
172172
used instead (see above).
173173

174174
- The `x.size` attribute on `torch.Tensor` is a function that behaves

array_api_compat/common/_aliases.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from types import ModuleType
1414
import inspect
1515

16-
from ._helpers import _check_device, _is_numpy_array, get_namespace
16+
from ._helpers import _check_device, _is_numpy_array, array_namespace
1717

1818
# These functions are modified from the NumPy versions.
1919

@@ -293,7 +293,7 @@ def _asarray(
293293
"""
294294
if namespace is None:
295295
try:
296-
xp = get_namespace(obj, _use_compat=False)
296+
xp = array_namespace(obj, _use_compat=False)
297297
except ValueError:
298298
# TODO: What about lists of arrays?
299299
raise ValueError("A namespace must be specified for asarray() with non-array input")

array_api_compat/common/_helpers.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _check_api_version(api_version):
5353
if api_version is not None and api_version != '2021.12':
5454
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
5555

56-
def get_namespace(*xs, api_version=None, _use_compat=True):
56+
def array_namespace(*xs, api_version=None, _use_compat=True):
5757
"""
5858
Get the array API compatible namespace for the arrays `xs`.
5959
@@ -62,7 +62,7 @@ def get_namespace(*xs, api_version=None, _use_compat=True):
6262
Typical usage is
6363
6464
def your_function(x, y):
65-
xp = array_api_compat.get_namespace(x, y)
65+
xp = array_api_compat.array_namespace(x, y)
6666
# Now use xp as the array library namespace
6767
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
6868
@@ -72,7 +72,7 @@ def your_function(x, y):
7272
namespaces = set()
7373
for x in xs:
7474
if isinstance(x, (tuple, list)):
75-
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
75+
namespaces.add(array_namespace(*x, _use_compat=_use_compat))
7676
elif hasattr(x, '__array_namespace__'):
7777
namespaces.add(x.__array_namespace__(api_version=api_version))
7878
elif _is_numpy_array(x):
@@ -113,6 +113,8 @@ def your_function(x, y):
113113

114114
return xp
115115

116+
# backwards compatibility alias
117+
get_namespace = array_namespace
116118

117119
def _check_device(xp, device):
118120
if xp == sys.modules.get('numpy'):
@@ -224,4 +226,4 @@ def size(x):
224226
return None
225227
return math.prod(x.shape)
226228

227-
__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device', 'size']
229+
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']

tests/test_array_namespace.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import array_api_compat
2+
from array_api_compat import array_namespace
3+
import pytest
4+
5+
6+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
7+
@pytest.mark.parametrize("api_version", [None, '2021.12'])
8+
def test_array_namespace(library, api_version):
9+
lib = pytest.importorskip(library)
10+
11+
array = lib.asarray([1.0, 2.0, 3.0])
12+
namespace = array_api_compat.array_namespace(array, api_version=api_version)
13+
14+
if 'array_api' in library:
15+
assert namespace == lib
16+
else:
17+
assert namespace == getattr(array_api_compat, library)
18+
19+
def test_array_namespace_multiple():
20+
import numpy as np
21+
22+
x = np.asarray([1, 2])
23+
assert array_namespace(x, x) == array_namespace((x, x)) == \
24+
array_namespace((x, x), x) == array_api_compat.numpy
25+
26+
def test_array_namespace_errors():
27+
pytest.raises(TypeError, lambda: array_namespace([1]))
28+
pytest.raises(TypeError, lambda: array_namespace())
29+
30+
import numpy as np
31+
import torch
32+
x = np.asarray([1, 2])
33+
y = torch.asarray([1, 2])
34+
35+
pytest.raises(TypeError, lambda: array_namespace(x, y))
36+
37+
pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
38+
39+
def test_get_namespace():
40+
# Backwards compatible wrapper
41+
assert array_api_compat.get_namespace is array_api_compat.array_namespace

tests/test_get_namespace.py

-37
This file was deleted.

0 commit comments

Comments
 (0)