Skip to content

Commit 311d0aa

Browse files
authored
Merge pull request #119 from asmeurer/asarray-copy
Support the copy keyword in asarray
2 parents ecb4c57 + 2dcd864 commit 311d0aa

File tree

18 files changed

+316
-136
lines changed

18 files changed

+316
-136
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
runs-on: ubuntu-latest
3838
strategy:
3939
matrix:
40-
python-version: ['3.8', '3.9', '3.10', '3.11']
40+
python-version: ['3.9', '3.10', '3.11', '3.12']
4141

4242
steps:
4343
- name: Checkout array-api-compat
@@ -55,16 +55,15 @@ jobs:
5555
with:
5656
python-version: ${{ matrix.python-version }}
5757
- name: Install dependencies
58-
# NumPy 1.21 doesn't support Python 3.11. NumPy 2.0 doesn't support
59-
# Python 3.8. There doesn't seem to be a way to put this in the numpy
60-
# 1.21 config file.
61-
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
58+
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
59+
# to put this in the numpy 1.21 config file.
60+
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
6261
run: |
6362
python -m pip install --upgrade pip
6463
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
6564
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
6665
- name: Run the array API testsuite (${{ inputs.package-name }})
67-
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
66+
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
6867
env:
6968
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
7069
# This enables the NEP 50 type promotion behavior (without it a lot of

.github/workflows/docs-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Docs Build
33
on: [push, pull_request]
44

55
jobs:
6-
build:
6+
docs-build:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v4

.github/workflows/docs-deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
- main
77

88
jobs:
9-
deploy:
9+
docs-deploy:
1010
runs-on: ubuntu-latest
1111
environment:
1212
name: docs-deploy

.github/workflows/tests.yml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ jobs:
55
runs-on: ubuntu-latest
66
strategy:
77
matrix:
8-
python-version: ['3.8', '3.9', '3.10', '3.11']
8+
python-version: ['3.9', '3.10', '3.11', '3.12']
9+
numpy-version: ['1.21', '1.26', 'dev']
10+
exclude:
11+
- python-version: '3.11'
12+
numpy-version: '1.21'
13+
- python-version: '3.12'
14+
numpy-version: '1.21'
915
fail-fast: true
1016
steps:
1117
- uses: actions/checkout@v4
@@ -15,11 +21,21 @@ jobs:
1521
- name: Install Dependencies
1622
run: |
1723
python -m pip install --upgrade pip
18-
python -m pip install pytest numpy torch dask[array] jax[cpu]
24+
if [ "${{ matrix.numpy-version }}" == "dev" ]; then
25+
PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
26+
elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then
27+
PIP_EXTRA='numpy==1.21.*'
28+
else
29+
PIP_EXTRA='numpy==1.26.*'
30+
fi
31+
python -m pip install -r requirements-dev.txt $PIP_EXTRA
1932
2033
- name: Run Tests
2134
run: |
22-
pytest
35+
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
36+
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
37+
fi
38+
pytest -v "${PYTEST_EXTRA[@]}"
2339
2440
# Make sure it installs
25-
python setup.py install
41+
python -m pip install .

array_api_compat/common/_aliases.py

Lines changed: 4 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
import numpy as np
109
from typing import Optional, Sequence, Tuple, Union
11-
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
10+
from ._typing import ndarray, Device, Dtype
1211

1312
from typing import NamedTuple
14-
from types import ModuleType
1513
import inspect
1614

17-
from ._helpers import _check_device, is_numpy_array, array_namespace
15+
from ._helpers import _check_device
1816

1917
# These functions are modified from the NumPy versions.
2018

19+
# Creation functions add the device keyword (which does nothing for NumPy)
20+
2121
def arange(
2222
start: Union[int, float],
2323
/,
@@ -268,90 +268,6 @@ def var(
268268
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269269
return xp.transpose(x, axes)
270270

271-
# Creation functions add the device keyword (which does nothing for NumPy)
272-
273-
# asarray also adds the copy keyword
274-
def _asarray(
275-
obj: Union[
276-
ndarray,
277-
bool,
278-
int,
279-
float,
280-
NestedSequence[bool | int | float],
281-
SupportsBufferProtocol,
282-
],
283-
/,
284-
*,
285-
dtype: Optional[Dtype] = None,
286-
device: Optional[Device] = None,
287-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
288-
namespace = None,
289-
**kwargs,
290-
) -> ndarray:
291-
"""
292-
Array API compatibility wrapper for asarray().
293-
294-
See the corresponding documentation in NumPy/CuPy and/or the array API
295-
specification for more details.
296-
297-
"""
298-
if namespace is None:
299-
try:
300-
xp = array_namespace(obj, _use_compat=False)
301-
except ValueError:
302-
# TODO: What about lists of arrays?
303-
raise ValueError("A namespace must be specified for asarray() with non-array input")
304-
elif isinstance(namespace, ModuleType):
305-
xp = namespace
306-
elif namespace == 'numpy':
307-
import numpy as xp
308-
elif namespace == 'cupy':
309-
import cupy as xp
310-
elif namespace == 'dask.array':
311-
import dask.array as xp
312-
else:
313-
raise ValueError("Unrecognized namespace argument to asarray()")
314-
315-
_check_device(xp, device)
316-
if is_numpy_array(obj):
317-
import numpy as np
318-
if hasattr(np, '_CopyMode'):
319-
# Not present in older NumPys
320-
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
321-
COPY_TRUE = (True, np._CopyMode.ALWAYS)
322-
else:
323-
COPY_FALSE = (False,)
324-
COPY_TRUE = (True,)
325-
else:
326-
COPY_FALSE = (False,)
327-
COPY_TRUE = (True,)
328-
if copy in COPY_FALSE and namespace != "dask.array":
329-
# copy=False is not yet implemented in xp.asarray
330-
raise NotImplementedError("copy=False is not yet implemented")
331-
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
332-
if dtype is not None and obj.dtype != dtype:
333-
copy = True
334-
if copy in COPY_TRUE:
335-
return xp.array(obj, copy=True, dtype=dtype)
336-
return obj
337-
elif namespace == "dask.array":
338-
if copy in COPY_TRUE:
339-
if dtype is None:
340-
return obj.copy()
341-
# Go through numpy, since dask copy is no-op by default
342-
import numpy as np
343-
obj = np.array(obj, dtype=dtype, copy=True)
344-
return xp.array(obj, dtype=dtype)
345-
else:
346-
import dask.array as da
347-
import numpy as np
348-
if not isinstance(obj, da.Array):
349-
obj = np.asarray(obj, dtype=dtype)
350-
return da.from_array(obj)
351-
return obj
352-
353-
return xp.asarray(obj, dtype=dtype, **kwargs)
354-
355271
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
356272
def reshape(x: ndarray,
357273
/,

array_api_compat/cupy/_aliases.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import annotations
22

3-
from functools import partial
4-
53
import cupy as cp
64

75
from ..common import _aliases
86
from .._internal import get_xp
97

10-
asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
11-
asarray.__doc__ = _aliases._asarray.__doc__
12-
del partial
8+
from typing import TYPE_CHECKING
9+
if TYPE_CHECKING:
10+
from typing import Optional, Union
11+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1312

1413
bool = cp.bool_
1514

@@ -62,6 +61,52 @@
6261
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6362
tensordot = get_xp(cp)(_aliases.tensordot)
6463

64+
_copy_default = object()
65+
66+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
67+
def asarray(
68+
obj: Union[
69+
ndarray,
70+
bool,
71+
int,
72+
float,
73+
NestedSequence[bool | int | float],
74+
SupportsBufferProtocol,
75+
],
76+
/,
77+
*,
78+
dtype: Optional[Dtype] = None,
79+
device: Optional[Device] = None,
80+
copy: Optional[bool] = _copy_default,
81+
**kwargs,
82+
) -> ndarray:
83+
"""
84+
Array API compatibility wrapper for asarray().
85+
86+
See the corresponding documentation in the array library and/or the array API
87+
specification for more details.
88+
"""
89+
with cp.cuda.Device(device):
90+
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
91+
# in asarray in numpy/_aliases.py.
92+
if copy is not _copy_default:
93+
# A future version of CuPy will change the meaning of copy=False
94+
# to mean no-copy. We don't know for certain what version it will
95+
# be yet, so to avoid breaking that version, we use a different
96+
# default value for copy so asarray(obj) with no copy kwarg will
97+
# always do the copy-if-needed behavior.
98+
99+
# This will still need to be updated to remove the
100+
# NotImplementedError for copy=False, but at least this won't
101+
# break the default or existing behavior.
102+
if copy is None:
103+
copy = False
104+
elif copy is False:
105+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
106+
kwargs['copy'] = copy
107+
108+
return cp.array(obj, dtype=dtype, **kwargs)
109+
65110
# These functions are completely new here. If the library already has them
66111
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67112
if hasattr(cp, 'vecdot'):
@@ -73,7 +118,7 @@
73118
else:
74119
isdtype = get_xp(cp)(_aliases.isdtype)
75120

76-
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
121+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77122
'acosh', 'asin', 'asinh', 'atan', 'atan2',
78123
'atanh', 'bitwise_left_shift', 'bitwise_invert',
79124
'bitwise_right_shift', 'concat', 'pow']

array_api_compat/dask/array/_aliases.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
if TYPE_CHECKING:
3838
from typing import Optional, Union
3939

40-
from ...common._typing import Device, Dtype, Array
40+
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
4141

4242
import dask.array as da
4343

@@ -76,10 +76,6 @@ def _dask_arange(
7676
arange = get_xp(da)(_dask_arange)
7777
eye = get_xp(da)(_aliases.eye)
7878

79-
from functools import partial
80-
asarray = partial(_aliases._asarray, namespace='dask.array')
81-
asarray.__doc__ = _aliases._asarray.__doc__
82-
8379
linspace = get_xp(da)(_aliases.linspace)
8480
eye = get_xp(da)(_aliases.eye)
8581
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
@@ -113,6 +109,47 @@ def _dask_arange(
113109
matmul = get_xp(np)(_aliases.matmul)
114110
tensordot = get_xp(np)(_aliases.tensordot)
115111

112+
113+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
114+
def asarray(
115+
obj: Union[
116+
Array,
117+
bool,
118+
int,
119+
float,
120+
NestedSequence[bool | int | float],
121+
SupportsBufferProtocol,
122+
],
123+
/,
124+
*,
125+
dtype: Optional[Dtype] = None,
126+
device: Optional[Device] = None,
127+
copy: "Optional[Union[bool, np._CopyMode]]" = None,
128+
**kwargs,
129+
) -> Array:
130+
"""
131+
Array API compatibility wrapper for asarray().
132+
133+
See the corresponding documentation in the array library and/or the array API
134+
specification for more details.
135+
"""
136+
if copy is False:
137+
# copy=False is not yet implemented in dask
138+
raise NotImplementedError("copy=False is not yet implemented")
139+
elif copy is True:
140+
if isinstance(obj, da.Array) and dtype is None:
141+
return obj.copy()
142+
# Go through numpy, since dask copy is no-op by default
143+
obj = np.array(obj, dtype=dtype, copy=True)
144+
return da.array(obj, dtype=dtype)
145+
else:
146+
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
147+
obj = np.asarray(obj, dtype=dtype)
148+
return da.from_array(obj)
149+
return obj
150+
151+
return da.asarray(obj, dtype=dtype, **kwargs)
152+
116153
from dask.array import (
117154
# Element wise aliases
118155
arccos as acos,

array_api_compat/numpy/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,10 @@
2121

2222
from ..common._helpers import * # noqa: F403
2323

24+
try:
25+
# Used in asarray(). Not present in older versions.
26+
from numpy import _CopyMode # noqa: F401
27+
except ImportError:
28+
pass
29+
2430
__array_api_version__ = '2022.12'

0 commit comments

Comments
 (0)