Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c54d5e0

Browse files
authoredJun 26, 2024··
Merge pull request #35 from asmeurer/2023.12
Preliminary 2023.12 support
2 parents f489d51 + 6f8c07f commit c54d5e0

27 files changed

+1394
-232
lines changed
 

‎.github/workflows/array-api-tests.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ on: [push, pull_request]
44

55
env:
66
PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200"
7+
API_VERSIONS: "2022.12 2023.12"
78

89
jobs:
910
array-api-tests:
@@ -45,9 +46,9 @@ jobs:
4546
- name: Run the array API testsuite
4647
env:
4748
ARRAY_API_TESTS_MODULE: array_api_strict
48-
# This enables the NEP 50 type promotion behavior (without it a lot of
49-
# tests fail in numpy 1.26 on bad scalar type promotion behavior)
50-
NPY_PROMOTION_STATE: weak
5149
run: |
52-
cd ${GITHUB_WORKSPACE}/array-api-tests
53-
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
50+
# Parameterizing this in the CI matrix is wasteful. Just do a loop here.
51+
for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do
52+
cd ${GITHUB_WORKSPACE}/array-api-tests
53+
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
54+
done

‎array_api_strict/__init__.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
1717
"""
1818

19+
__all__ = []
20+
1921
# Warning: __array_api_version__ could change globally with
2022
# set_array_api_strict_flags(). This should always be accessed as an
2123
# attribute, like xp.__array_api_version__, or using
2224
# array_api_strict.get_array_api_strict_flags()['api_version'].
2325
from ._flags import API_VERSION as __array_api_version__
2426

25-
__all__ = ["__array_api_version__"]
27+
__all__ += ["__array_api_version__"]
2628

2729
from ._constants import e, inf, nan, pi, newaxis
2830

@@ -137,7 +139,9 @@
137139
bitwise_right_shift,
138140
bitwise_xor,
139141
ceil,
142+
clip,
140143
conj,
144+
copysign,
141145
cos,
142146
cosh,
143147
divide,
@@ -148,6 +152,7 @@
148152
floor_divide,
149153
greater,
150154
greater_equal,
155+
hypot,
151156
imag,
152157
isfinite,
153158
isinf,
@@ -163,6 +168,8 @@
163168
logical_not,
164169
logical_or,
165170
logical_xor,
171+
maximum,
172+
minimum,
166173
multiply,
167174
negative,
168175
not_equal,
@@ -172,6 +179,7 @@
172179
remainder,
173180
round,
174181
sign,
182+
signbit,
175183
sin,
176184
sinh,
177185
square,
@@ -199,7 +207,9 @@
199207
"bitwise_right_shift",
200208
"bitwise_xor",
201209
"ceil",
210+
"clip",
202211
"conj",
212+
"copysign",
203213
"cos",
204214
"cosh",
205215
"divide",
@@ -210,6 +220,7 @@
210220
"floor_divide",
211221
"greater",
212222
"greater_equal",
223+
"hypot",
213224
"imag",
214225
"isfinite",
215226
"isinf",
@@ -225,6 +236,8 @@
225236
"logical_not",
226237
"logical_or",
227238
"logical_xor",
239+
"maximum",
240+
"minimum",
228241
"multiply",
229242
"negative",
230243
"not_equal",
@@ -234,6 +247,7 @@
234247
"remainder",
235248
"round",
236249
"sign",
250+
"signbit",
237251
"sin",
238252
"sinh",
239253
"square",
@@ -248,35 +262,36 @@
248262

249263
__all__ += ["take"]
250264

251-
# linalg is an extension in the array API spec, which is a sub-namespace. Only
252-
# a subset of functions in it are imported into the top-level namespace.
253-
from . import linalg
265+
from ._info import __array_namespace_info__
254266

255-
__all__ += ["linalg"]
267+
__all__ += [
268+
"__array_namespace_info__",
269+
]
256270

257271
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
258272

259273
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
260274

261-
from . import fft
262-
__all__ += ["fft"]
263-
264275
from ._manipulation_functions import (
265276
concat,
266277
expand_dims,
267278
flip,
279+
moveaxis,
268280
permute_dims,
281+
repeat,
269282
reshape,
270283
roll,
271284
squeeze,
272285
stack,
286+
tile,
287+
unstack,
273288
)
274289

275-
__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
290+
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
276291

277-
from ._searching_functions import argmax, argmin, nonzero, where
292+
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
278293

279-
__all__ += ["argmax", "argmin", "nonzero", "where"]
294+
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
280295

281296
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
282297

@@ -286,9 +301,9 @@
286301

287302
__all__ += ["argsort", "sort"]
288303

289-
from ._statistical_functions import max, mean, min, prod, std, sum, var
304+
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
290305

291-
__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
306+
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
292307

293308
from ._utility_functions import all, any
294309

@@ -308,3 +323,22 @@
308323
from . import _version
309324
__version__ = _version.get_versions()['version']
310325
del _version
326+
327+
328+
# Extensions can be enabled or disabled dynamically. In order to make
329+
# "array_api_strict.linalg" give an AttributeError when it is disabled, we
330+
# use __getattr__. Note that linalg and fft are dynamically added and removed
331+
# from __all__ in set_array_api_strict_flags.
332+
333+
def __getattr__(name):
334+
if name in ['linalg', 'fft']:
335+
if name in get_array_api_strict_flags()['enabled_extensions']:
336+
if name == 'linalg':
337+
from . import _linalg
338+
return _linalg
339+
elif name == 'fft':
340+
from . import _fft
341+
return _fft
342+
else:
343+
raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict")
344+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

‎array_api_strict/_array_object.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __repr__(self):
5151

5252
CPU_DEVICE = _cpu_device()
5353

54+
_default = object()
55+
5456
class Array:
5557
"""
5658
n-d array object for the array API namespace.
@@ -437,7 +439,7 @@ def _validate_index(self, key):
437439
"Array API when the array is the sole index."
438440
)
439441
if not get_array_api_strict_flags()['boolean_indexing']:
440-
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")
442+
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")
441443

442444
elif i.dtype in _integer_dtypes and i.ndim != 0:
443445
raise IndexError(
@@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
525527
res = self._array.__complex__()
526528
return res
527529

528-
def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
530+
def __dlpack__(
531+
self: Array,
532+
/,
533+
*,
534+
stream: Optional[Union[int, Any]] = None,
535+
max_version: Optional[tuple[int, int]] = _default,
536+
dl_device: Optional[tuple[IntEnum, int]] = _default,
537+
copy: Optional[bool] = _default,
538+
) -> PyCapsule:
529539
"""
530540
Performs the operation __dlpack__.
531541
"""
542+
if get_array_api_strict_flags()['api_version'] < '2023.12':
543+
if max_version is not _default:
544+
raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API")
545+
if dl_device is not _default:
546+
raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API")
547+
if copy is not _default:
548+
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
549+
550+
# Going to wait for upstream numpy support
551+
if max_version not in [_default, None]:
552+
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
553+
if dl_device not in [_default, None]:
554+
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
555+
if copy not in [_default, None]:
556+
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
557+
532558
return self._array.__dlpack__(stream=stream)
533559

534560
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
@@ -1142,7 +1168,7 @@ def device(self) -> Device:
11421168
# Note: mT is new in array API spec (see matrix_transpose)
11431169
@property
11441170
def mT(self) -> Array:
1145-
from .linalg import matrix_transpose
1171+
from ._linear_algebra_functions import matrix_transpose
11461172
return matrix_transpose(self)
11471173

11481174
@property

‎array_api_strict/_creation_functions.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
SupportsBufferProtocol,
1313
)
1414
from ._dtypes import _DType, _all_dtypes
15+
from ._flags import get_array_api_strict_flags
1516

1617
import numpy as np
1718

1819

1920
def _check_valid_dtype(dtype):
2021
# Note: Only spelling dtypes as the dtype objects is supported.
2122
if dtype not in (None,) + _all_dtypes:
22-
raise ValueError("dtype must be one of the supported dtypes")
23+
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
2324

2425
def _supports_buffer_protocol(obj):
2526
try:
@@ -28,6 +29,14 @@ def _supports_buffer_protocol(obj):
2829
return False
2930
return True
3031

32+
def _check_device(device):
33+
# _array_object imports in this file are inside the functions to avoid
34+
# circular imports
35+
from ._array_object import CPU_DEVICE
36+
37+
if device not in [CPU_DEVICE, None]:
38+
raise ValueError(f"Unsupported device {device!r}")
39+
3140
def asarray(
3241
obj: Union[
3342
Array,
@@ -48,16 +57,13 @@ def asarray(
4857
4958
See its docstring for more information.
5059
"""
51-
# _array_object imports in this file are inside the functions to avoid
52-
# circular imports
53-
from ._array_object import Array, CPU_DEVICE
60+
from ._array_object import Array
5461

5562
_check_valid_dtype(dtype)
5663
_np_dtype = None
5764
if dtype is not None:
5865
_np_dtype = dtype._np_dtype
59-
if device not in [CPU_DEVICE, None]:
60-
raise ValueError(f"Unsupported device {device!r}")
66+
_check_device(device)
6167

6268
if np.__version__[0] < '2':
6369
if copy is False:
@@ -106,11 +112,11 @@ def arange(
106112
107113
See its docstring for more information.
108114
"""
109-
from ._array_object import Array, CPU_DEVICE
115+
from ._array_object import Array
110116

111117
_check_valid_dtype(dtype)
112-
if device not in [CPU_DEVICE, None]:
113-
raise ValueError(f"Unsupported device {device!r}")
118+
_check_device(device)
119+
114120
if dtype is not None:
115121
dtype = dtype._np_dtype
116122
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
@@ -127,11 +133,11 @@ def empty(
127133
128134
See its docstring for more information.
129135
"""
130-
from ._array_object import Array, CPU_DEVICE
136+
from ._array_object import Array
131137

132138
_check_valid_dtype(dtype)
133-
if device not in [CPU_DEVICE, None]:
134-
raise ValueError(f"Unsupported device {device!r}")
139+
_check_device(device)
140+
135141
if dtype is not None:
136142
dtype = dtype._np_dtype
137143
return Array._new(np.empty(shape, dtype=dtype))
@@ -145,11 +151,11 @@ def empty_like(
145151
146152
See its docstring for more information.
147153
"""
148-
from ._array_object import Array, CPU_DEVICE
154+
from ._array_object import Array
149155

150156
_check_valid_dtype(dtype)
151-
if device not in [CPU_DEVICE, None]:
152-
raise ValueError(f"Unsupported device {device!r}")
157+
_check_device(device)
158+
153159
if dtype is not None:
154160
dtype = dtype._np_dtype
155161
return Array._new(np.empty_like(x._array, dtype=dtype))
@@ -169,19 +175,39 @@ def eye(
169175
170176
See its docstring for more information.
171177
"""
172-
from ._array_object import Array, CPU_DEVICE
178+
from ._array_object import Array
173179

174180
_check_valid_dtype(dtype)
175-
if device not in [CPU_DEVICE, None]:
176-
raise ValueError(f"Unsupported device {device!r}")
181+
_check_device(device)
182+
177183
if dtype is not None:
178184
dtype = dtype._np_dtype
179185
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
180186

181187

182-
def from_dlpack(x: object, /) -> Array:
188+
_default = object()
189+
190+
def from_dlpack(
191+
x: object,
192+
/,
193+
*,
194+
device: Optional[Device] = _default,
195+
copy: Optional[bool] = _default,
196+
) -> Array:
183197
from ._array_object import Array
184198

199+
if get_array_api_strict_flags()['api_version'] < '2023.12':
200+
if device is not _default:
201+
raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API")
202+
if copy is not _default:
203+
raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API")
204+
205+
# Going to wait for upstream numpy support
206+
if device is not _default:
207+
_check_device(device)
208+
if copy not in [_default, None]:
209+
raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")
210+
185211
return Array._new(np.from_dlpack(x))
186212

187213

@@ -197,11 +223,11 @@ def full(
197223
198224
See its docstring for more information.
199225
"""
200-
from ._array_object import Array, CPU_DEVICE
226+
from ._array_object import Array
201227

202228
_check_valid_dtype(dtype)
203-
if device not in [CPU_DEVICE, None]:
204-
raise ValueError(f"Unsupported device {device!r}")
229+
_check_device(device)
230+
205231
if isinstance(fill_value, Array) and fill_value.ndim == 0:
206232
fill_value = fill_value._array
207233
if dtype is not None:
@@ -227,11 +253,11 @@ def full_like(
227253
228254
See its docstring for more information.
229255
"""
230-
from ._array_object import Array, CPU_DEVICE
256+
from ._array_object import Array
231257

232258
_check_valid_dtype(dtype)
233-
if device not in [CPU_DEVICE, None]:
234-
raise ValueError(f"Unsupported device {device!r}")
259+
_check_device(device)
260+
235261
if dtype is not None:
236262
dtype = dtype._np_dtype
237263
res = np.full_like(x._array, fill_value, dtype=dtype)
@@ -257,11 +283,11 @@ def linspace(
257283
258284
See its docstring for more information.
259285
"""
260-
from ._array_object import Array, CPU_DEVICE
286+
from ._array_object import Array
261287

262288
_check_valid_dtype(dtype)
263-
if device not in [CPU_DEVICE, None]:
264-
raise ValueError(f"Unsupported device {device!r}")
289+
_check_device(device)
290+
265291
if dtype is not None:
266292
dtype = dtype._np_dtype
267293
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
@@ -298,11 +324,11 @@ def ones(
298324
299325
See its docstring for more information.
300326
"""
301-
from ._array_object import Array, CPU_DEVICE
327+
from ._array_object import Array
302328

303329
_check_valid_dtype(dtype)
304-
if device not in [CPU_DEVICE, None]:
305-
raise ValueError(f"Unsupported device {device!r}")
330+
_check_device(device)
331+
306332
if dtype is not None:
307333
dtype = dtype._np_dtype
308334
return Array._new(np.ones(shape, dtype=dtype))
@@ -316,11 +342,11 @@ def ones_like(
316342
317343
See its docstring for more information.
318344
"""
319-
from ._array_object import Array, CPU_DEVICE
345+
from ._array_object import Array
320346

321347
_check_valid_dtype(dtype)
322-
if device not in [CPU_DEVICE, None]:
323-
raise ValueError(f"Unsupported device {device!r}")
348+
_check_device(device)
349+
324350
if dtype is not None:
325351
dtype = dtype._np_dtype
326352
return Array._new(np.ones_like(x._array, dtype=dtype))
@@ -365,11 +391,11 @@ def zeros(
365391
366392
See its docstring for more information.
367393
"""
368-
from ._array_object import Array, CPU_DEVICE
394+
from ._array_object import Array
369395

370396
_check_valid_dtype(dtype)
371-
if device not in [CPU_DEVICE, None]:
372-
raise ValueError(f"Unsupported device {device!r}")
397+
_check_device(device)
398+
373399
if dtype is not None:
374400
dtype = dtype._np_dtype
375401
return Array._new(np.zeros(shape, dtype=dtype))
@@ -383,11 +409,11 @@ def zeros_like(
383409
384410
See its docstring for more information.
385411
"""
386-
from ._array_object import Array, CPU_DEVICE
412+
from ._array_object import Array
387413

388414
_check_valid_dtype(dtype)
389-
if device not in [CPU_DEVICE, None]:
390-
raise ValueError(f"Unsupported device {device!r}")
415+
_check_device(device)
416+
391417
if dtype is not None:
392418
dtype = dtype._np_dtype
393419
return Array._new(np.zeros_like(x._array, dtype=dtype))

‎array_api_strict/_data_type_functions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4+
from ._creation_functions import _check_device
45
from ._dtypes import (
56
_DType,
67
_all_dtypes,
@@ -13,19 +14,30 @@
1314
_numeric_dtypes,
1415
_result_type,
1516
)
17+
from ._flags import get_array_api_strict_flags
1618

1719
from dataclasses import dataclass
1820
from typing import TYPE_CHECKING
1921

2022
if TYPE_CHECKING:
21-
from typing import List, Tuple, Union
22-
from ._typing import Dtype
23+
from typing import List, Tuple, Union, Optional
24+
from ._typing import Dtype, Device
2325

2426
import numpy as np
2527

28+
# Use to emulate the asarray(device) argument not existing in 2022.12
29+
_default = object()
2630

2731
# Note: astype is a function, not an array method as in NumPy.
28-
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
32+
def astype(
33+
x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default
34+
) -> Array:
35+
if device is not _default:
36+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
37+
_check_device(device)
38+
else:
39+
raise TypeError("The device argument to astype requires at least version 2023.12 of the array API")
40+
2941
if not copy and dtype == x.dtype:
3042
return x
3143
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy))

‎array_api_strict/_elementwise_functions.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
_result_type,
1313
)
1414
from ._array_object import Array
15+
from ._flags import requires_api_version
16+
from ._creation_functions import asarray
17+
18+
from typing import Optional, Union
1519

1620
import numpy as np
1721

@@ -240,6 +244,70 @@ def ceil(x: Array, /) -> Array:
240244
return x
241245
return Array._new(np.ceil(x._array))
242246

247+
# WARNING: This function is not yet tested by the array-api-tests test suite.
248+
249+
# Note: min and max argument names are different and not optional in numpy.
250+
@requires_api_version('2023.12')
251+
def clip(
252+
x: Array,
253+
/,
254+
min: Optional[Union[int, float, Array]] = None,
255+
max: Optional[Union[int, float, Array]] = None,
256+
) -> Array:
257+
"""
258+
Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`.
259+
260+
See its docstring for more information.
261+
"""
262+
if (x.dtype not in _real_numeric_dtypes
263+
or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes
264+
or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes):
265+
raise TypeError("Only real numeric dtypes are allowed in clip")
266+
if not isinstance(min, (int, float, Array, type(None))):
267+
raise TypeError("min must be an None, int, float, or an array")
268+
if not isinstance(max, (int, float, Array, type(None))):
269+
raise TypeError("max must be an None, int, float, or an array")
270+
271+
# Mixed dtype kinds is implementation defined
272+
if (x.dtype in _integer_dtypes
273+
and (isinstance(min, float) or
274+
isinstance(min, Array) and min.dtype in _real_floating_dtypes)):
275+
raise TypeError("min must be integral when x is integral")
276+
if (x.dtype in _integer_dtypes
277+
and (isinstance(max, float) or
278+
isinstance(max, Array) and max.dtype in _real_floating_dtypes)):
279+
raise TypeError("max must be integral when x is integral")
280+
if (x.dtype in _real_floating_dtypes
281+
and (isinstance(min, int) or
282+
isinstance(min, Array) and min.dtype in _integer_dtypes)):
283+
raise TypeError("min must be floating-point when x is floating-point")
284+
if (x.dtype in _real_floating_dtypes
285+
and (isinstance(max, int) or
286+
isinstance(max, Array) and max.dtype in _integer_dtypes)):
287+
raise TypeError("max must be floating-point when x is floating-point")
288+
289+
if min is max is None:
290+
# Note: NumPy disallows min = max = None
291+
return x
292+
293+
# Normalize to make the below logic simpler
294+
if min is not None:
295+
min = asarray(min)._array
296+
if max is not None:
297+
max = asarray(max)._array
298+
299+
# min > max is implementation defined
300+
if min is not None and max is not None and np.any(min > max):
301+
raise ValueError("min must be less than or equal to max")
302+
303+
result = np.clip(x._array, min, max)
304+
# Note: NumPy applies type promotion, but the standard specifies the
305+
# return dtype should be the same as x
306+
if result.dtype != x.dtype._np_dtype:
307+
# TODO: I'm not completely sure this always gives the correct thing
308+
# for integer dtypes. See https://github.com/numpy/numpy/issues/24976
309+
result = result.astype(x.dtype._np_dtype)
310+
return Array._new(result)
243311

244312
def conj(x: Array, /) -> Array:
245313
"""
@@ -251,6 +319,19 @@ def conj(x: Array, /) -> Array:
251319
raise TypeError("Only complex floating-point dtypes are allowed in conj")
252320
return Array._new(np.conj(x))
253321

322+
@requires_api_version('2023.12')
323+
def copysign(x1: Array, x2: Array, /) -> Array:
324+
"""
325+
Array API compatible wrapper for :py:func:`np.copysign <numpy.copysign>`.
326+
327+
See its docstring for more information.
328+
"""
329+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
330+
raise TypeError("Only real numeric dtypes are allowed in copysign")
331+
# Call result type here just to raise on disallowed type combinations
332+
_result_type(x1.dtype, x2.dtype)
333+
x1, x2 = Array._normalize_two_args(x1, x2)
334+
return Array._new(np.copysign(x1._array, x2._array))
254335

255336
def cos(x: Array, /) -> Array:
256337
"""
@@ -377,6 +458,19 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
377458
x1, x2 = Array._normalize_two_args(x1, x2)
378459
return Array._new(np.greater_equal(x1._array, x2._array))
379460

461+
@requires_api_version('2023.12')
462+
def hypot(x1: Array, x2: Array, /) -> Array:
463+
"""
464+
Array API compatible wrapper for :py:func:`np.hypot <numpy.hypot>`.
465+
466+
See its docstring for more information.
467+
"""
468+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
469+
raise TypeError("Only real floating-point dtypes are allowed in hypot")
470+
# Call result type here just to raise on disallowed type combinations
471+
_result_type(x1.dtype, x2.dtype)
472+
x1, x2 = Array._normalize_two_args(x1, x2)
473+
return Array._new(np.hypot(x1._array, x2._array))
380474

381475
def imag(x: Array, /) -> Array:
382476
"""
@@ -560,6 +654,35 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
560654
x1, x2 = Array._normalize_two_args(x1, x2)
561655
return Array._new(np.logical_xor(x1._array, x2._array))
562656

657+
@requires_api_version('2023.12')
658+
def maximum(x1: Array, x2: Array, /) -> Array:
659+
"""
660+
Array API compatible wrapper for :py:func:`np.maximum <numpy.maximum>`.
661+
662+
See its docstring for more information.
663+
"""
664+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
665+
raise TypeError("Only real numeric dtypes are allowed in maximum")
666+
# Call result type here just to raise on disallowed type combinations
667+
_result_type(x1.dtype, x2.dtype)
668+
x1, x2 = Array._normalize_two_args(x1, x2)
669+
# TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error
670+
# in that case?
671+
return Array._new(np.maximum(x1._array, x2._array))
672+
673+
@requires_api_version('2023.12')
674+
def minimum(x1: Array, x2: Array, /) -> Array:
675+
"""
676+
Array API compatible wrapper for :py:func:`np.minimum <numpy.minimum>`.
677+
678+
See its docstring for more information.
679+
"""
680+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
681+
raise TypeError("Only real numeric dtypes are allowed in minimum")
682+
# Call result type here just to raise on disallowed type combinations
683+
_result_type(x1.dtype, x2.dtype)
684+
x1, x2 = Array._normalize_two_args(x1, x2)
685+
return Array._new(np.minimum(x1._array, x2._array))
563686

564687
def multiply(x1: Array, x2: Array, /) -> Array:
565688
"""
@@ -671,6 +794,18 @@ def sign(x: Array, /) -> Array:
671794
return Array._new(np.sign(x._array))
672795

673796

797+
@requires_api_version('2023.12')
798+
def signbit(x: Array, /) -> Array:
799+
"""
800+
Array API compatible wrapper for :py:func:`np.signbit <numpy.signbit>`.
801+
802+
See its docstring for more information.
803+
"""
804+
if x.dtype not in _real_floating_dtypes:
805+
raise TypeError("Only real floating-point dtypes are allowed in signbit")
806+
return Array._new(np.signbit(x._array))
807+
808+
674809
def sin(x: Array, /) -> Array:
675810
"""
676811
Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
File renamed without changes.

‎array_api_strict/_flags.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
supported_versions = (
2222
"2021.12",
2323
"2022.12",
24+
"2023.12",
2425
)
2526

2627
API_VERSION = default_version = "2022.12"
@@ -70,6 +71,8 @@ def set_array_api_strict_flags(
7071
Note that 2021.12 is supported, but currently gives the same thing as
7172
2022.12 (except that the fft extension will be disabled).
7273
74+
2023.12 support is experimental. Some features in 2023.12 may still be
75+
missing, and it hasn't been fully tested.
7376
7477
- `boolean_indexing`: Whether indexing by a boolean array is supported.
7578
Note that although boolean array indexing does result in data-dependent
@@ -86,9 +89,9 @@ def set_array_api_strict_flags(
8689
The functions that make use of data-dependent shapes, and are therefore
8790
disabled by setting this flag to False are
8891
89-
- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
90-
- `nonzero`
91-
- `repeat` when the `repeats` argument is an array (requires 2023.12
92+
- `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`.
93+
- `nonzero()`
94+
- `repeat()` when the `repeats` argument is an array (requires 2023.12
9295
version of the standard)
9396
9497
Note that while boolean indexing is also data-dependent, it is
@@ -133,7 +136,9 @@ def set_array_api_strict_flags(
133136
if api_version not in supported_versions:
134137
raise ValueError(f"Unsupported standard version {api_version!r}")
135138
if api_version == "2021.12":
136-
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
139+
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
140+
if api_version == "2023.12":
141+
warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2)
137142
API_VERSION = api_version
138143
array_api_strict.__array_api_version__ = API_VERSION
139144

@@ -154,7 +159,11 @@ def set_array_api_strict_flags(
154159
)
155160
ENABLED_EXTENSIONS = tuple(enabled_extensions)
156161
else:
157-
ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION])
162+
ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION])
163+
164+
array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
165+
set(array_api_strict.__all__) -
166+
set(default_extensions))
158167

159168
# We have to do this separately or it won't get added as the docstring
160169
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
@@ -172,6 +181,14 @@ def get_array_api_strict_flags():
172181
This function is **not** part of the array API standard. It only exists
173182
in array-api-strict.
174183
184+
.. note::
185+
186+
The `inspection API
187+
<https://data-apis.org/array-api/latest/API_specification/inspection.html>`__
188+
provides a portable way to access most of this information. However, it
189+
is only present in standard versions starting with 2023.12. The array
190+
API version can be accessed portably using `xp.__array_api_version__`.
191+
175192
Returns
176193
-------
177194
dict
@@ -280,29 +297,51 @@ def __exit__(self, exc_type, exc_value, traceback):
280297

281298
# Private functions
282299

300+
ENVIRONMENT_VARIABLES = [
301+
"ARRAY_API_STRICT_API_VERSION",
302+
"ARRAY_API_STRICT_BOOLEAN_INDEXING",
303+
"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES",
304+
"ARRAY_API_STRICT_ENABLED_EXTENSIONS",
305+
]
306+
283307
def set_flags_from_environment():
308+
kwargs = {}
284309
if "ARRAY_API_STRICT_API_VERSION" in os.environ:
285-
set_array_api_strict_flags(
286-
api_version=os.environ["ARRAY_API_STRICT_API_VERSION"]
287-
)
310+
kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"]
288311

289312
if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ:
290-
set_array_api_strict_flags(
291-
boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true"
292-
)
313+
kwargs["boolean_indexing"] = os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true"
293314

294315
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
295-
set_array_api_strict_flags(
296-
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
297-
)
316+
kwargs["data_dependent_shapes"] = os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
298317

299318
if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ:
300-
set_array_api_strict_flags(
301-
enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",")
302-
)
319+
enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",")
320+
if enabled_extensions == [""]:
321+
enabled_extensions = []
322+
kwargs["enabled_extensions"] = enabled_extensions
323+
324+
# Called unconditionally because it is needed at first import to add
325+
# linalg and fft to __all__
326+
set_array_api_strict_flags(**kwargs)
303327

304328
set_flags_from_environment()
305329

330+
# Decorators
331+
332+
def requires_api_version(version):
333+
def decorator(func):
334+
@functools.wraps(func)
335+
def wrapper(*args, **kwargs):
336+
if version > API_VERSION:
337+
raise RuntimeError(
338+
f"The function {func.__name__} requires API version {version} or later, "
339+
f"but the current API version for array-api-strict is {API_VERSION}"
340+
)
341+
return func(*args, **kwargs)
342+
return wrapper
343+
return decorator
344+
306345
def requires_data_dependent_shapes(func):
307346
@functools.wraps(func)
308347
def wrapper(*args, **kwargs):

‎array_api_strict/_info.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Optional, Union, Tuple, List
7+
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info
8+
9+
from ._array_object import CPU_DEVICE
10+
from ._flags import get_array_api_strict_flags, requires_api_version
11+
from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
12+
13+
@requires_api_version('2023.12')
14+
def __array_namespace_info__() -> Info:
15+
import array_api_strict._info
16+
return array_api_strict._info
17+
18+
@requires_api_version('2023.12')
19+
def capabilities() -> Capabilities:
20+
flags = get_array_api_strict_flags()
21+
return {"boolean indexing": flags['boolean_indexing'],
22+
"data-dependent shapes": flags['data_dependent_shapes'],
23+
}
24+
25+
@requires_api_version('2023.12')
26+
def default_device() -> device:
27+
return CPU_DEVICE
28+
29+
@requires_api_version('2023.12')
30+
def default_dtypes(
31+
*,
32+
device: Optional[device] = None,
33+
) -> DefaultDataTypes:
34+
return {
35+
"real floating": float64,
36+
"complex floating": complex128,
37+
"integral": int64,
38+
"indexing": int64,
39+
}
40+
41+
@requires_api_version('2023.12')
42+
def dtypes(
43+
*,
44+
device: Optional[device] = None,
45+
kind: Optional[Union[str, Tuple[str, ...]]] = None,
46+
) -> DataTypes:
47+
if kind is None:
48+
return {
49+
"bool": bool,
50+
"int8": int8,
51+
"int16": int16,
52+
"int32": int32,
53+
"int64": int64,
54+
"uint8": uint8,
55+
"uint16": uint16,
56+
"uint32": uint32,
57+
"uint64": uint64,
58+
"float32": float32,
59+
"float64": float64,
60+
"complex64": complex64,
61+
"complex128": complex128,
62+
}
63+
if kind == "bool":
64+
return {"bool": bool}
65+
if kind == "signed integer":
66+
return {
67+
"int8": int8,
68+
"int16": int16,
69+
"int32": int32,
70+
"int64": int64,
71+
}
72+
if kind == "unsigned integer":
73+
return {
74+
"uint8": uint8,
75+
"uint16": uint16,
76+
"uint32": uint32,
77+
"uint64": uint64,
78+
}
79+
if kind == "integral":
80+
return {
81+
"int8": int8,
82+
"int16": int16,
83+
"int32": int32,
84+
"int64": int64,
85+
"uint8": uint8,
86+
"uint16": uint16,
87+
"uint32": uint32,
88+
"uint64": uint64,
89+
}
90+
if kind == "real floating":
91+
return {
92+
"float32": float32,
93+
"float64": float64,
94+
}
95+
if kind == "complex floating":
96+
return {
97+
"complex64": complex64,
98+
"complex128": complex128,
99+
}
100+
if kind == "numeric":
101+
return {
102+
"int8": int8,
103+
"int16": int16,
104+
"int32": int32,
105+
"int64": int64,
106+
"uint8": uint8,
107+
"uint16": uint16,
108+
"uint32": uint32,
109+
"uint64": uint64,
110+
"float32": float32,
111+
"float64": float64,
112+
"complex64": complex64,
113+
"complex128": complex128,
114+
}
115+
if isinstance(kind, tuple):
116+
res = {}
117+
for k in kind:
118+
res.update(dtypes(kind=k))
119+
return res
120+
raise ValueError(f"unsupported kind: {kind!r}")
121+
122+
@requires_api_version('2023.12')
123+
def devices() -> List[device]:
124+
return [CPU_DEVICE]
125+
126+
__all__ = [
127+
"capabilities",
128+
"default_device",
129+
"default_dtypes",
130+
"devices",
131+
"dtypes",
132+
]

‎array_api_strict/linalg.py renamed to ‎array_api_strict/_linalg.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._manipulation_functions import reshape
1212
from ._elementwise_functions import conj
1313
from ._array_object import Array
14-
from ._flags import requires_extension
14+
from ._flags import requires_extension, get_array_api_strict_flags
1515

1616
try:
1717
from numpy._core.numeric import normalize_axis_tuple
@@ -80,6 +80,17 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
8080
# Note: this is different from np.cross(), which allows dimension 2
8181
if x1.shape[axis] != 3:
8282
raise ValueError('cross() dimension must equal 3')
83+
84+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
85+
if axis >= 0:
86+
raise ValueError("axis must be negative in cross")
87+
elif axis < min(-1, -x1.ndim, -x2.ndim):
88+
raise ValueError("axis is out of bounds for x1 and x2")
89+
90+
# Prior to 2023.12, there was ambiguity in the standard about whether
91+
# positive axis applied before or after broadcasting. NumPy applies
92+
# the axis before broadcasting. Since that behavior is what has always
93+
# been implemented here, we keep it for backwards compatibility.
8394
return Array._new(np.cross(x1._array, x2._array, axis=axis))
8495

8596
@requires_extension('linalg')
@@ -377,10 +388,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr
377388
# Note: trace() works the same as sum() and prod() (see
378389
# _statistical_functions.py)
379390
if dtype is None:
380-
if x.dtype == float32:
381-
dtype = np.float64
382-
elif x.dtype == complex64:
383-
dtype = np.complex128
391+
if get_array_api_strict_flags()['api_version'] < '2023.12':
392+
if x.dtype == float32:
393+
dtype = np.float64
394+
elif x.dtype == complex64:
395+
dtype = np.complex128
384396
else:
385397
dtype = dtype._np_dtype
386398
# Note: trace always operates on the last two axes, whereas np.trace

‎array_api_strict/_linear_algebra_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from __future__ import annotations
99

1010
from ._dtypes import _numeric_dtypes
11-
1211
from ._array_object import Array
12+
from ._flags import get_array_api_strict_flags
1313

1414
from typing import TYPE_CHECKING
1515
if TYPE_CHECKING:
@@ -54,6 +54,19 @@ def matrix_transpose(x: Array, /) -> Array:
5454
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
5555
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
5656
raise TypeError('Only numeric dtypes are allowed in vecdot')
57+
58+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
59+
if axis >= 0:
60+
raise ValueError("axis must be negative in vecdot")
61+
elif axis < min(-1, -x1.ndim, -x2.ndim):
62+
raise ValueError("axis is out of bounds for x1 and x2")
63+
64+
# In versions of the standard prior to 2023.12, vecdot applied axis after
65+
# broadcasting. This is different from applying it before broadcasting
66+
# when axis is nonnegative. The below code keeps this behavior for
67+
# 2022.12, primarily for backwards compatibility. Note that the behavior
68+
# is unambiguous when axis is negative, so the below code should work
69+
# correctly in that case regardless of which version is used.
5770
ndim = max(x1.ndim, x2.ndim)
5871
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
5972
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)

‎array_api_strict/_manipulation_functions.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4+
from ._creation_functions import asarray
45
from ._data_type_functions import result_type
6+
from ._dtypes import _integer_dtypes
7+
from ._flags import requires_api_version, get_array_api_strict_flags
58

69
from typing import TYPE_CHECKING
710

@@ -43,6 +46,19 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
4346
"""
4447
return Array._new(np.flip(x._array, axis=axis))
4548

49+
@requires_api_version('2023.12')
50+
def moveaxis(
51+
x: Array,
52+
source: Union[int, Tuple[int, ...]],
53+
destination: Union[int, Tuple[int, ...]],
54+
/,
55+
) -> Array:
56+
"""
57+
Array API compatible wrapper for :py:func:`np.moveaxis <numpy.moveaxis>`.
58+
59+
See its docstring for more information.
60+
"""
61+
return Array._new(np.moveaxis(x._array, source, destination))
4662

4763
# Note: The function name is different here (see also matrix_transpose).
4864
# Unlike transpose(), the axes argument is required.
@@ -54,6 +70,31 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
5470
"""
5571
return Array._new(np.transpose(x._array, axes))
5672

73+
@requires_api_version('2023.12')
74+
def repeat(
75+
x: Array,
76+
repeats: Union[int, Array],
77+
/,
78+
*,
79+
axis: Optional[int] = None,
80+
) -> Array:
81+
"""
82+
Array API compatible wrapper for :py:func:`np.repeat <numpy.repeat>`.
83+
84+
See its docstring for more information.
85+
"""
86+
if isinstance(repeats, Array):
87+
data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes']
88+
if not data_dependent_shapes:
89+
raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
90+
if repeats.dtype not in _integer_dtypes:
91+
raise TypeError("The repeats array must have an integer dtype")
92+
elif isinstance(repeats, int):
93+
repeats = asarray(repeats)
94+
else:
95+
raise TypeError("repeats must be an int or array")
96+
97+
return Array._new(np.repeat(x._array, repeats, axis=axis))
5798

5899
# Note: the optional argument is called 'shape', not 'newshape'
59100
def reshape(x: Array,
@@ -113,3 +154,28 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
113154
result_type(*arrays)
114155
arrays = tuple(a._array for a in arrays)
115156
return Array._new(np.stack(arrays, axis=axis))
157+
158+
159+
@requires_api_version('2023.12')
160+
def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
161+
"""
162+
Array API compatible wrapper for :py:func:`np.tile <numpy.tile>`.
163+
164+
See its docstring for more information.
165+
"""
166+
# Note: NumPy allows repetitions to be an int or array
167+
if not isinstance(repetitions, tuple):
168+
raise TypeError("repetitions must be a tuple")
169+
return Array._new(np.tile(x._array, repetitions))
170+
171+
# Note: this function is new
172+
@requires_api_version('2023.12')
173+
def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]:
174+
if not (-x.ndim <= axis < x.ndim):
175+
raise ValueError("axis out of range")
176+
177+
if axis < 0:
178+
axis += x.ndim
179+
180+
slices = (slice(None),) * axis
181+
return tuple(x[slices + (i, ...)] for i in range(x.shape[axis]))

‎array_api_strict/_searching_functions.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes
5-
from ._flags import requires_data_dependent_shapes
5+
from ._flags import requires_data_dependent_shapes, requires_api_version
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Optional, Tuple
9+
from typing import Literal, Optional, Tuple
1010

1111
import numpy as np
1212

@@ -45,6 +45,28 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
4545
raise ValueError("nonzero is not allowed on 0-dimensional arrays")
4646
return tuple(Array._new(i) for i in np.nonzero(x._array))
4747

48+
@requires_api_version('2023.12')
49+
def searchsorted(
50+
x1: Array,
51+
x2: Array,
52+
/,
53+
*,
54+
side: Literal["left", "right"] = "left",
55+
sorter: Optional[Array] = None,
56+
) -> Array:
57+
"""
58+
Array API compatible wrapper for :py:func:`np.searchsorted <numpy.searchsorted>`.
59+
60+
See its docstring for more information.
61+
"""
62+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
63+
raise TypeError("Only real numeric dtypes are allowed in searchsorted")
64+
sorter = sorter._array if sorter is not None else None
65+
# TODO: The sort order of nans and signed zeros is implementation
66+
# dependent. Should we error/warn if they are present?
67+
68+
# x1 must be 1-D, but NumPy already requires this.
69+
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter))
4870

4971
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
5072
"""

‎array_api_strict/_statistical_functions.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
)
88
from ._array_object import Array
99
from ._dtypes import float32, complex64
10+
from ._flags import requires_api_version, get_array_api_strict_flags
11+
from ._creation_functions import zeros
12+
from ._manipulation_functions import concat
1013

1114
from typing import TYPE_CHECKING
1215

@@ -16,6 +19,32 @@
1619

1720
import numpy as np
1821

22+
@requires_api_version('2023.12')
23+
def cumulative_sum(
24+
x: Array,
25+
/,
26+
*,
27+
axis: Optional[int] = None,
28+
dtype: Optional[Dtype] = None,
29+
include_initial: bool = False,
30+
) -> Array:
31+
if x.dtype not in _numeric_dtypes:
32+
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
33+
dt = x.dtype if dtype is None else dtype
34+
if dtype is not None:
35+
dtype = dtype._np_dtype
36+
37+
# TODO: The standard is not clear about what should happen when x.ndim == 0.
38+
if axis is None:
39+
if x.ndim > 1:
40+
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
41+
axis = 0
42+
# np.cumsum does not support include_initial
43+
if include_initial:
44+
if axis < 0:
45+
axis += x.ndim
46+
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
47+
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype))
1948

2049
def max(
2150
x: Array,
@@ -63,14 +92,16 @@ def prod(
6392
) -> Array:
6493
if x.dtype not in _numeric_dtypes:
6594
raise TypeError("Only numeric dtypes are allowed in prod")
66-
# Note: sum() and prod() always upcast for dtype=None. `np.prod` does that
67-
# for integers, but not for float32 or complex64, so we need to
68-
# special-case it here
95+
6996
if dtype is None:
70-
if x.dtype == float32:
71-
dtype = np.float64
72-
elif x.dtype == complex64:
73-
dtype = np.complex128
97+
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
98+
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
99+
# NumPy (only upcast for integral dtypes).
100+
if get_array_api_strict_flags()['api_version'] < '2023.12':
101+
if x.dtype == float32:
102+
dtype = np.float64
103+
elif x.dtype == complex64:
104+
dtype = np.complex128
74105
else:
75106
dtype = dtype._np_dtype
76107
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
@@ -100,14 +131,16 @@ def sum(
100131
) -> Array:
101132
if x.dtype not in _numeric_dtypes:
102133
raise TypeError("Only numeric dtypes are allowed in sum")
103-
# Note: sum() and prod() always upcast for dtype=None. `np.sum` does that
104-
# for integers, but not for float32 or complex64, so we need to
105-
# special-case it here
134+
106135
if dtype is None:
107-
if x.dtype == float32:
108-
dtype = np.float64
109-
elif x.dtype == complex64:
110-
dtype = np.complex128
136+
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
137+
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
138+
# NumPy (only upcast for integral dtypes).
139+
if get_array_api_strict_flags()['api_version'] < '2023.12':
140+
if x.dtype == float32:
141+
dtype = np.float64
142+
elif x.dtype == complex64:
143+
dtype = np.complex128
111144
else:
112145
dtype = dtype._np_dtype
113146
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))

‎array_api_strict/_typing.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
from typing import (
2323
Any,
24+
ModuleType,
25+
TypedDict,
2426
TypeVar,
2527
Protocol,
2628
)
@@ -39,6 +41,8 @@ def __len__(self, /) -> int: ...
3941

4042
Dtype = _DType
4143

44+
Info = ModuleType
45+
4246
if sys.version_info >= (3, 12):
4347
from collections.abc import Buffer as SupportsBufferProtocol
4448
else:
@@ -48,3 +52,37 @@ def __len__(self, /) -> int: ...
4852

4953
class SupportsDLPack(Protocol):
5054
def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
55+
56+
Capabilities = TypedDict(
57+
"Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool}
58+
)
59+
60+
DefaultDataTypes = TypedDict(
61+
"DefaultDataTypes",
62+
{
63+
"real floating": Dtype,
64+
"complex floating": Dtype,
65+
"integral": Dtype,
66+
"indexing": Dtype,
67+
},
68+
)
69+
70+
DataTypes = TypedDict(
71+
"DataTypes",
72+
{
73+
"bool": Dtype,
74+
"float32": Dtype,
75+
"float64": Dtype,
76+
"complex64": Dtype,
77+
"complex128": Dtype,
78+
"int8": Dtype,
79+
"int16": Dtype,
80+
"int32": Dtype,
81+
"int64": Dtype,
82+
"uint8": Dtype,
83+
"uint16": Dtype,
84+
"uint32": Dtype,
85+
"uint64": Dtype,
86+
},
87+
total=False,
88+
)

‎array_api_strict/tests/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
from .._flags import reset_array_api_strict_flags
1+
import os
2+
3+
from .._flags import reset_array_api_strict_flags, ENVIRONMENT_VARIABLES
24

35
import pytest
46

7+
def pytest_configure(config):
8+
for env_var in ENVIRONMENT_VARIABLES:
9+
if env_var in os.environ:
10+
pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.")
11+
512
@pytest.fixture(autouse=True)
613
def reset_flags():
714
reset_array_api_strict_flags()

‎array_api_strict/tests/test_array_object.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
uint64,
2424
bool as bool_,
2525
)
26+
from .._flags import set_array_api_strict_flags
27+
2628
import array_api_strict
2729

2830
def test_validate_index():
@@ -410,13 +412,46 @@ def test_array_namespace():
410412
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
411413
assert array_api_strict.__array_api_version__ == "2022.12"
412414

415+
with pytest.warns(UserWarning):
416+
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
417+
assert array_api_strict.__array_api_version__ == "2023.12"
418+
413419
with pytest.warns(UserWarning):
414420
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
415421
assert array_api_strict.__array_api_version__ == "2021.12"
416422

417423
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
418-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12"))
424+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
419425

420426
def test_no_iter():
421427
pytest.raises(TypeError, lambda: iter(ones(3)))
422428
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
429+
430+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
431+
def dlpack_2023_12(api_version):
432+
if api_version != '2022.12':
433+
with pytest.warns(UserWarning):
434+
set_array_api_strict_flags(api_version=api_version)
435+
else:
436+
set_array_api_strict_flags(api_version=api_version)
437+
438+
a = asarray([1, 2, 3], dtype=int8)
439+
# Never an error
440+
a.__dlpack__()
441+
442+
443+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
444+
pytest.raises(exception, lambda:
445+
a.__dlpack__(dl_device=CPU_DEVICE))
446+
pytest.raises(exception, lambda:
447+
a.__dlpack__(dl_device=None))
448+
pytest.raises(exception, lambda:
449+
a.__dlpack__(max_version=(1, 0)))
450+
pytest.raises(exception, lambda:
451+
a.__dlpack__(max_version=None))
452+
pytest.raises(exception, lambda:
453+
a.__dlpack__(copy=False))
454+
pytest.raises(exception, lambda:
455+
a.__dlpack__(copy=True))
456+
pytest.raises(exception, lambda:
457+
a.__dlpack__(copy=None))

‎array_api_strict/tests/test_creation_functions.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from numpy.testing import assert_raises
44
import numpy as np
55

6+
import pytest
7+
68
from .. import all
79
from .._creation_functions import (
810
asarray,
911
arange,
1012
empty,
1113
empty_like,
1214
eye,
15+
from_dlpack,
1316
full,
1417
full_like,
1518
linspace,
@@ -21,7 +24,7 @@
2124
)
2225
from .._dtypes import float32, float64
2326
from .._array_object import Array, CPU_DEVICE
24-
27+
from .._flags import set_array_api_strict_flags
2528

2629
def test_asarray_errors():
2730
# Test various protections against incorrect usage
@@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors():
188191
meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
189192

190193
assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))
194+
195+
196+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
197+
def from_dlpack_2023_12(api_version):
198+
if api_version != '2022.12':
199+
with pytest.warns(UserWarning):
200+
set_array_api_strict_flags(api_version=api_version)
201+
else:
202+
set_array_api_strict_flags(api_version=api_version)
203+
204+
a = asarray([1., 2., 3.], dtype=float64)
205+
# Never an error
206+
capsule = a.__dlpack__()
207+
from_dlpack(capsule)
208+
209+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
210+
pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE))
211+
pytest.raises(exception, lambda: from_dlpack(capsule, device=None))
212+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=False))
213+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=True))
214+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=None))

‎array_api_strict/tests/test_data_type_functions.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,68 @@
33
import pytest
44

55
from numpy.testing import assert_raises
6-
import array_api_strict as xp
76
import numpy as np
87

8+
from .._creation_functions import asarray
9+
from .._data_type_functions import astype, can_cast, isdtype
10+
from .._dtypes import (
11+
bool, int8, int16, uint8, float64,
12+
)
13+
from .._flags import set_array_api_strict_flags
14+
15+
916
@pytest.mark.parametrize(
1017
"from_, to, expected",
1118
[
12-
(xp.int8, xp.int16, True),
13-
(xp.int16, xp.int8, False),
14-
(xp.bool, xp.int8, False),
15-
(xp.asarray(0, dtype=xp.uint8), xp.int8, False),
19+
(int8, int16, True),
20+
(int16, int8, False),
21+
(bool, int8, False),
22+
(asarray(0, dtype=uint8), int8, False),
1623
],
1724
)
1825
def test_can_cast(from_, to, expected):
1926
"""
2027
can_cast() returns correct result
2128
"""
22-
assert xp.can_cast(from_, to) == expected
29+
assert can_cast(from_, to) == expected
2330

2431
def test_isdtype_strictness():
25-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64))
26-
assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8'))
32+
assert_raises(TypeError, lambda: isdtype(float64, 64))
33+
assert_raises(ValueError, lambda: isdtype(float64, 'f8'))
2734

28-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),)))
35+
assert_raises(TypeError, lambda: isdtype(float64, (('integral',),)))
2936
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
3037
warnings.simplefilter("always")
31-
xp.isdtype(xp.float64, np.object_)
38+
isdtype(float64, np.object_)
3239
assert len(w) == 1
3340
assert issubclass(w[-1].category, UserWarning)
3441

35-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None))
42+
assert_raises(TypeError, lambda: isdtype(float64, None))
3643
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
3744
warnings.simplefilter("always")
38-
xp.isdtype(xp.float64, np.float64)
45+
isdtype(float64, np.float64)
3946
assert len(w) == 1
4047
assert issubclass(w[-1].category, UserWarning)
48+
49+
50+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
51+
def astype_device(api_version):
52+
if api_version != '2022.12':
53+
with pytest.warns(UserWarning):
54+
set_array_api_strict_flags(api_version=api_version)
55+
else:
56+
set_array_api_strict_flags(api_version=api_version)
57+
58+
a = asarray([1, 2, 3], dtype=int8)
59+
# Never an error
60+
astype(a, int16)
61+
62+
# Always an error
63+
astype(a, int16, device="cpu")
64+
65+
if api_version >= '2023.12':
66+
astype(a, int8, device=None)
67+
astype(a, int8, device=a.device)
68+
else:
69+
pytest.raises(TypeError, lambda: astype(a, int8, device=None))
70+
pytest.raises(TypeError, lambda: astype(a, int8, device=a.device))

‎array_api_strict/tests/test_elementwise_functions.py

Lines changed: 81 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from inspect import getfullargspec
1+
from inspect import getfullargspec, getmodule
22

33
from numpy.testing import assert_raises
44

@@ -10,79 +10,93 @@
1010
_floating_dtypes,
1111
_integer_dtypes,
1212
)
13+
from .._flags import set_array_api_strict_flags
1314

15+
import pytest
1416

1517
def nargs(func):
1618
return len(getfullargspec(func).args)
1719

1820

21+
elementwise_function_input_types = {
22+
"abs": "numeric",
23+
"acos": "floating-point",
24+
"acosh": "floating-point",
25+
"add": "numeric",
26+
"asin": "floating-point",
27+
"asinh": "floating-point",
28+
"atan": "floating-point",
29+
"atan2": "real floating-point",
30+
"atanh": "floating-point",
31+
"bitwise_and": "integer or boolean",
32+
"bitwise_invert": "integer or boolean",
33+
"bitwise_left_shift": "integer",
34+
"bitwise_or": "integer or boolean",
35+
"bitwise_right_shift": "integer",
36+
"bitwise_xor": "integer or boolean",
37+
"ceil": "real numeric",
38+
"clip": "real numeric",
39+
"conj": "complex floating-point",
40+
"copysign": "real floating-point",
41+
"cos": "floating-point",
42+
"cosh": "floating-point",
43+
"divide": "floating-point",
44+
"equal": "all",
45+
"exp": "floating-point",
46+
"expm1": "floating-point",
47+
"floor": "real numeric",
48+
"floor_divide": "real numeric",
49+
"greater": "real numeric",
50+
"greater_equal": "real numeric",
51+
"hypot": "real floating-point",
52+
"imag": "complex floating-point",
53+
"isfinite": "numeric",
54+
"isinf": "numeric",
55+
"isnan": "numeric",
56+
"less": "real numeric",
57+
"less_equal": "real numeric",
58+
"log": "floating-point",
59+
"logaddexp": "real floating-point",
60+
"log10": "floating-point",
61+
"log1p": "floating-point",
62+
"log2": "floating-point",
63+
"logical_and": "boolean",
64+
"logical_not": "boolean",
65+
"logical_or": "boolean",
66+
"logical_xor": "boolean",
67+
"maximum": "real numeric",
68+
"minimum": "real numeric",
69+
"multiply": "numeric",
70+
"negative": "numeric",
71+
"not_equal": "all",
72+
"positive": "numeric",
73+
"pow": "numeric",
74+
"real": "complex floating-point",
75+
"remainder": "real numeric",
76+
"round": "numeric",
77+
"sign": "numeric",
78+
"signbit": "real floating-point",
79+
"sin": "floating-point",
80+
"sinh": "floating-point",
81+
"sqrt": "floating-point",
82+
"square": "numeric",
83+
"subtract": "numeric",
84+
"tan": "floating-point",
85+
"tanh": "floating-point",
86+
"trunc": "real numeric",
87+
}
88+
89+
def test_missing_functions():
90+
# Ensure the above dictionary is complete.
91+
import array_api_strict._elementwise_functions as mod
92+
mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
93+
assert set(mod_funcs) == set(elementwise_function_input_types)
94+
1995
def test_function_types():
2096
# Test that every function accepts only the required input types. We only
2197
# test the negative cases here (error). The positive cases are tested in
2298
# the array API test suite.
2399

24-
elementwise_function_input_types = {
25-
"abs": "numeric",
26-
"acos": "floating-point",
27-
"acosh": "floating-point",
28-
"add": "numeric",
29-
"asin": "floating-point",
30-
"asinh": "floating-point",
31-
"atan": "floating-point",
32-
"atan2": "real floating-point",
33-
"atanh": "floating-point",
34-
"bitwise_and": "integer or boolean",
35-
"bitwise_invert": "integer or boolean",
36-
"bitwise_left_shift": "integer",
37-
"bitwise_or": "integer or boolean",
38-
"bitwise_right_shift": "integer",
39-
"bitwise_xor": "integer or boolean",
40-
"ceil": "real numeric",
41-
"conj": "complex floating-point",
42-
"cos": "floating-point",
43-
"cosh": "floating-point",
44-
"divide": "floating-point",
45-
"equal": "all",
46-
"exp": "floating-point",
47-
"expm1": "floating-point",
48-
"floor": "real numeric",
49-
"floor_divide": "real numeric",
50-
"greater": "real numeric",
51-
"greater_equal": "real numeric",
52-
"imag": "complex floating-point",
53-
"isfinite": "numeric",
54-
"isinf": "numeric",
55-
"isnan": "numeric",
56-
"less": "real numeric",
57-
"less_equal": "real numeric",
58-
"log": "floating-point",
59-
"logaddexp": "real floating-point",
60-
"log10": "floating-point",
61-
"log1p": "floating-point",
62-
"log2": "floating-point",
63-
"logical_and": "boolean",
64-
"logical_not": "boolean",
65-
"logical_or": "boolean",
66-
"logical_xor": "boolean",
67-
"multiply": "numeric",
68-
"negative": "numeric",
69-
"not_equal": "all",
70-
"positive": "numeric",
71-
"pow": "numeric",
72-
"real": "complex floating-point",
73-
"remainder": "real numeric",
74-
"round": "numeric",
75-
"sign": "numeric",
76-
"sin": "floating-point",
77-
"sinh": "floating-point",
78-
"sqrt": "floating-point",
79-
"square": "numeric",
80-
"subtract": "numeric",
81-
"tan": "floating-point",
82-
"tanh": "floating-point",
83-
"trunc": "real numeric",
84-
}
85-
86100
def _array_vals():
87101
for d in _integer_dtypes:
88102
yield asarray(1, dtype=d)
@@ -91,6 +105,10 @@ def _array_vals():
91105
for d in _floating_dtypes:
92106
yield asarray(1.0, dtype=d)
93107

108+
# Use the latest version of the standard so all functions are included
109+
with pytest.warns(UserWarning):
110+
set_array_api_strict_flags(api_version="2023.12")
111+
94112
for x in _array_vals():
95113
for func_name, types in elementwise_function_input_types.items():
96114
dtypes = _dtype_categories[types]

‎array_api_strict/tests/test_flags.py

Lines changed: 332 additions & 44 deletions
Large diffs are not rendered by default.

‎array_api_strict/tests/test_linalg.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import pytest
2+
3+
from .._flags import set_array_api_strict_flags
4+
5+
import array_api_strict as xp
6+
7+
# TODO: Maybe all of these exceptions should be IndexError?
8+
9+
# Technically this is linear_algebra, not linalg, but it's simpler to keep
10+
# both of these tests together
11+
def test_vecdot_2023_12():
12+
# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
13+
# 0 behavior (which is primarily kept for backwards compatibility).
14+
15+
a = xp.ones((2, 3, 4, 5))
16+
b = xp.ones(( 3, 4, 1))
17+
18+
# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
19+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
20+
assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5)
21+
assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5)
22+
# This is disallowed because the arrays must have the same values before
23+
# broadcasting
24+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1))
25+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4))
26+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3))
27+
28+
# Out-of-bounds axes even after broadcasting
29+
pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=4))
30+
pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=-5))
31+
32+
# negative axis behavior is unambiguous when it's within the bounds of
33+
# both arrays before broadcasting
34+
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
35+
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)
36+
37+
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
38+
# min(x1.ndim, x2.ndim), which is unambiguous
39+
with pytest.warns(UserWarning):
40+
set_array_api_strict_flags(api_version='2023.12')
41+
42+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
43+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1))
44+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=2))
45+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3))
46+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1))
47+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4))
48+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=4))
49+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-5))
50+
51+
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
52+
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)
53+
54+
@pytest.mark.parametrize('api_version', ['2021.12', '2022.12', '2023.12'])
55+
def test_cross(api_version):
56+
# This test tests everything that should be the same across all supported
57+
# API versions.
58+
59+
if api_version != '2022.12':
60+
with pytest.warns(UserWarning):
61+
set_array_api_strict_flags(api_version=api_version)
62+
else:
63+
set_array_api_strict_flags(api_version=api_version)
64+
65+
a = xp.ones((2, 4, 5, 3))
66+
b = xp.ones(( 4, 1, 3))
67+
assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3)
68+
69+
a = xp.ones((2, 4, 3, 5))
70+
b = xp.ones(( 4, 3, 1))
71+
assert xp.linalg.cross(a, b, axis=-2).shape == (2, 4, 3, 5)
72+
73+
# This is disallowed because the axes must equal 3 before broadcasting
74+
a = xp.ones((3, 2, 3, 5))
75+
b = xp.ones(( 2, 1, 1))
76+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-1))
77+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2))
78+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3))
79+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4))
80+
81+
# Out-of-bounds axes even after broadcasting
82+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4))
83+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5))
84+
85+
@pytest.mark.parametrize('api_version', ['2021.12', '2022.12'])
86+
def test_cross_2022_12(api_version):
87+
# Test the 2022.12 axis >= 0 behavior, which is primarily kept for
88+
# backwards compatibility. Note that unlike vecdot, array_api_strict
89+
# cross() never implemented the "after broadcasting" axis behavior, but
90+
# just reused NumPy cross(), which applies axes before broadcasting.
91+
if api_version != '2022.12':
92+
with pytest.warns(UserWarning):
93+
set_array_api_strict_flags(api_version=api_version)
94+
else:
95+
set_array_api_strict_flags(api_version=api_version)
96+
97+
a = xp.ones((3, 2, 4, 5))
98+
b = xp.ones((3, 2, 4, 1))
99+
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)
100+
101+
# ambiguous case
102+
a = xp.ones(( 3, 4, 5))
103+
b = xp.ones((3, 2, 4, 1))
104+
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)
105+
106+
def test_cross_2023_12():
107+
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
108+
# min(x1.ndim, x2.ndim), which is unambiguous
109+
with pytest.warns(UserWarning):
110+
set_array_api_strict_flags(api_version='2023.12')
111+
112+
a = xp.ones((3, 2, 4, 5))
113+
b = xp.ones((3, 2, 4, 1))
114+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))
115+
116+
a = xp.ones(( 3, 4, 5))
117+
b = xp.ones((3, 2, 4, 1))
118+
pytest.raises(ValueError, lambda: xp. linalg.cross(a, b, axis=0))
119+
120+
a = xp.ones((2, 4, 5, 3))
121+
b = xp.ones(( 4, 1, 3))
122+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))
123+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=1))
124+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=2))
125+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=3))
126+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2))
127+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3))
128+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4))
129+
130+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4))
131+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5))
132+
133+
assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3)

‎array_api_strict/tests/test_manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_reshape_copy():
2525
a = asarray(np.ones((2, 3)))
2626
b = reshape(a, (3, 2), copy=True)
2727
assert not np.shares_memory(a._array, b._array)
28-
28+
2929
a = asarray(np.ones((2, 3)))
3030
b = reshape(a, (3, 2), copy=False)
3131
assert np.shares_memory(a._array, b._array)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
3+
from .._flags import set_array_api_strict_flags
4+
5+
import array_api_strict as xp
6+
7+
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
8+
def test_sum_prod_trace_2023_12(func_name):
9+
# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
10+
# with dtype=None
11+
if func_name == 'trace':
12+
func = getattr(xp.linalg, func_name)
13+
else:
14+
func = getattr(xp, func_name)
15+
16+
a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32)
17+
a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64)
18+
a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32)
19+
20+
assert func(a_real).dtype == xp.float64
21+
assert func(a_complex).dtype == xp.complex128
22+
assert func(a_int).dtype == xp.int64
23+
24+
with pytest.warns(UserWarning):
25+
set_array_api_strict_flags(api_version='2023.12')
26+
27+
assert func(a_real).dtype == xp.float32
28+
assert func(a_complex).dtype == xp.complex64
29+
assert func(a_int).dtype == xp.int64

‎docs/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Array API Strict Flags
1111
.. currentmodule:: array_api_strict
1212

1313
.. autofunction:: get_array_api_strict_flags
14+
15+
.. _set_array_api_strict_flags:
1416
.. autofunction:: set_array_api_strict_flags
1517
.. autofunction:: reset_array_api_strict_flags
1618
.. autoclass:: ArrayAPIStrictFlags

‎docs/index.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@ libraries. Consuming library code should use the
1515
support the array API. Rather, it is intended to be used in the test suites of
1616
consuming libraries to test their array API usage.
1717

18-
array-api-strict currently supports the 2022.12 version of the standard.
19-
2023.12 support is planned and is tracked by [this
20-
issue](https://github.com/data-apis/array-api-strict/issues/25).
18+
array-api-strict currently supports the
19+
[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12)
20+
version of the standard. Experimental
21+
[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12)
22+
support is implemented, [but must be enabled with a
23+
flag](set_array_api_strict_flags).
2124

2225
## Install
2326

@@ -179,9 +182,11 @@ issue, but this hasn't necessarily been tested thoroughly.
179182
function. array-api-strict currently implements all of these. In the
180183
future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7).
181184

182-
6. array-api-strict currently only supports the 2022.12 version of the array
183-
API standard. [Support for 2023.12 is
184-
planned](https://github.com/data-apis/array-api-strict/issues/25).
185+
6. array-api-strict currently uses the 2022.12 version of the array API
186+
standard. Support for 2023.12 is implemented but is still experimental and
187+
not fully tested. It can be enabled with
188+
[`array_api_strict.set_array_api_strict_flags(api_version='2023.12')`](set_array_api_strict_flags).
189+
185190

186191
(numpy.array_api)=
187192
## Relationship to `numpy.array_api`

‎pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[pytest]
2+
filterwarnings = error

0 commit comments

Comments
 (0)
Please sign in to comment.