Skip to content

Commit 2c32ae3

Browse files
committed
ENH: Array API 2024.12: binary ops vs. Python scalars
1 parent e3e9a83 commit 2c32ae3

File tree

6 files changed

+241
-19
lines changed

6 files changed

+241
-19
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56+
"array-api": ("https://data-apis.org/array-api/draft", None),
5657
"jax": ("https://jax.readthedocs.io/en/latest", None),
5758
}
5859

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def isclose(
5252
5353
Parameters
5454
----------
55-
a, b : Array
56-
Input arrays to compare.
55+
a, b : Array | int | float | complex | bool
56+
Input objects to compare. At least one must be an Array API object.
5757
rtol : array_like, optional
5858
The relative tolerance parameter (see Notes).
5959
atol : array_like, optional

src/array_api_extra/_lib/_funcs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15+
from ._utils._helpers import asarrays
1516
from ._utils._typing import Array
1617

1718
__all__ = [
@@ -315,6 +316,7 @@ def isclose(
315316
xp: ModuleType,
316317
) -> Array: # numpydoc ignore=PR01,RT01
317318
"""See docstring in array_api_extra._delegation."""
319+
a, b = asarrays(a, b, xp=xp)
318320

319321
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
320322
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -350,8 +352,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
350352
351353
Parameters
352354
----------
353-
a, b : array
354-
Input arrays.
355+
a, b : Array | int | float | complex
356+
Input arrays or scalars. At least one must be an Array API object.
355357
xp : array_namespace, optional
356358
The standard-compatible namespace for `a` and `b`. Default: infer.
357359
@@ -414,10 +416,10 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
414416
"""
415417
if xp is None:
416418
xp = array_namespace(a, b)
419+
a, b = asarrays(a, b, xp=xp)
417420

418-
b = xp.asarray(b)
419421
singletons = (1,) * (b.ndim - a.ndim)
420-
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
422+
a = xp.broadcast_to(a, singletons + a.shape)
421423

422424
nd_b, nd_a = b.ndim, a.ndim
423425
nd_max = max(nd_b, nd_a)
@@ -577,6 +579,7 @@ def setdiff1d(
577579
"""
578580
if xp is None:
579581
xp = array_namespace(x1, x2)
582+
x1, x2 = asarrays(x1, x2, xp=xp)
580583

581584
if assume_unique:
582585
x1 = xp.reshape(x1, (-1,))

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from __future__ import annotations
55

66
from types import ModuleType
7+
from typing import cast
78

89
from . import _compat
10+
from ._compat import is_array_api_obj, is_numpy_array
911
from ._typing import Array
1012

1113
__all__ = ["in1d", "mean"]
@@ -91,3 +93,85 @@ def mean(
9193
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
9294
return mean_real + (mean_imag * xp.asarray(1j))
9395
return xp.mean(x, axis=axis, keepdims=keepdims)
96+
97+
98+
def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
99+
"""Return True if `x` is a Python scalar, False otherwise."""
100+
# isinstance(x, float) returns True for np.float64
101+
# isinstance(x, complex) returns True for np.complex128
102+
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
103+
104+
105+
def asarrays(
106+
a: Array | int | float | complex | bool,
107+
b: Array | int | float | complex | bool,
108+
xp: ModuleType,
109+
) -> tuple[Array, Array]:
110+
"""
111+
Ensure both `a` and `b` are arrays.
112+
113+
If `b` is a python scalar, it is converted to the same dtype as `a`, and vice versa.
114+
115+
Behavior is not specified when mixing a Python ``float`` and an array with an
116+
integer data type; this may give ``float32``, ``float64``, or raise an exception.
117+
Behavior is implementation-specific.
118+
119+
Similarly, behavior is not specified when mixing a Python ``complex`` and an array
120+
with a real-valued data type; this may give ``complex64``, ``complex128``, or raise
121+
an exception. Behavior is implementation-specific.
122+
123+
Parameters
124+
----------
125+
a, b : Array | int | float | complex | bool
126+
Input arrays or scalars. At least one must be an array.
127+
xp : ModuleType
128+
The array API namespace.
129+
130+
Returns
131+
-------
132+
Array, Array
133+
The input arrays, possibly converted to arrays if they were scalars.
134+
135+
See Also
136+
--------
137+
mixing-arrays-with-python-scalars : Array API specification for the behavior.
138+
"""
139+
a_scalar = is_python_scalar(a)
140+
b_scalar = is_python_scalar(b)
141+
if not a_scalar and not b_scalar:
142+
return a, b # This includes misc. malformed input e.g. str
143+
144+
swap = False
145+
if a_scalar:
146+
swap = True
147+
b, a = a, b
148+
149+
if is_array_api_obj(a):
150+
# a is an Array API object
151+
# b is a int | float | complex | bool
152+
153+
# pyright doesn't like it if you reuse the same variable name
154+
xa = cast(Array, a)
155+
156+
# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
157+
same_dtype = {
158+
bool: "bool",
159+
int: ("integral", "real floating", "complex floating"),
160+
float: ("real floating", "complex floating"),
161+
complex: "complex floating",
162+
}
163+
kind = same_dtype[type(b)] # type: ignore[index]
164+
if xp.isdtype(xa.dtype, kind):
165+
xb = xp.asarray(b, dtype=xa.dtype)
166+
else:
167+
# Undefined behaviour. Let the function deal with it, if it can.
168+
xb = xp.asarray(b)
169+
170+
else:
171+
# Neither a nor b are Array API objects.
172+
# Note: we can only reach this point when one explicitly passes
173+
# xp=xp to the calling function; otherwise we fail earlier on
174+
# array_namespace(a, b).
175+
xa, xb = xp.asarray(a), xp.asarray(b)
176+
177+
return (xb, xa) if swap else (xa, xb)

tests/test_funcs.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,24 @@ def test_none_shape_bool(self, xp: ModuleType):
394394
a = a[a]
395395
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
396396

397+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
398+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
399+
def test_python_scalar(self, xp: ModuleType):
400+
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
401+
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
402+
xp_assert_equal(isclose(0.0, a), xp.asarray([True, False]))
403+
404+
a = xp.asarray([0, 1], dtype=xp.int16)
405+
xp_assert_equal(isclose(a, 0), xp.asarray([True, False]))
406+
xp_assert_equal(isclose(0, a), xp.asarray([True, False]))
407+
408+
xp_assert_equal(isclose(0, 0, xp=xp), xp.asarray(True))
409+
xp_assert_equal(isclose(0, 1, xp=xp), xp.asarray(False))
410+
411+
def test_all_python_scalars(self):
412+
with pytest.raises(TypeError, match="Unrecognized"):
413+
isclose(0, 0)
414+
397415
def test_xp(self, xp: ModuleType):
398416
a = xp.asarray([0.0, 0.0])
399417
b = xp.asarray([1e-9, 1e-4])
@@ -406,30 +424,22 @@ def test_basic(self, xp: ModuleType):
406424
# Using 0-dimensional array
407425
a = xp.asarray(1)
408426
b = xp.asarray([[1, 2], [3, 4]])
409-
k = xp.asarray([[1, 2], [3, 4]])
410-
xp_assert_equal(kron(a, b), k)
411-
a = xp.asarray([[1, 2], [3, 4]])
412-
b = xp.asarray(1)
413-
xp_assert_equal(kron(a, b), k)
427+
xp_assert_equal(kron(a, b), b)
428+
xp_assert_equal(kron(b, a), b)
414429

415430
# Using 1-dimensional array
416431
a = xp.asarray([3])
417432
b = xp.asarray([[1, 2], [3, 4]])
418433
k = xp.asarray([[3, 6], [9, 12]])
419434
xp_assert_equal(kron(a, b), k)
420-
a = xp.asarray([[1, 2], [3, 4]])
421-
b = xp.asarray([3])
422-
xp_assert_equal(kron(a, b), k)
435+
xp_assert_equal(kron(b, a), k)
423436

424437
# Using 3-dimensional array
425438
a = xp.asarray([[[1]], [[2]]])
426439
b = xp.asarray([[1, 2], [3, 4]])
427440
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
428441
xp_assert_equal(kron(a, b), k)
429-
a = xp.asarray([[1, 2], [3, 4]])
430-
b = xp.asarray([[[1]], [[2]]])
431-
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
432-
xp_assert_equal(kron(a, b), k)
442+
xp_assert_equal(kron(b, a), k)
433443

434444
def test_kron_smoke(self, xp: ModuleType):
435445
a = xp.ones((3, 3))
@@ -467,6 +477,18 @@ def test_kron_shape(
467477
k = kron(a, b)
468478
assert k.shape == expected_shape
469479

480+
def test_python_scalar(self, xp: ModuleType):
481+
a = 1
482+
# Test no dtype promotion to xp.asarray(a); use b.dtype
483+
b = xp.asarray([[1, 2], [3, 4]], dtype=xp.int16)
484+
xp_assert_equal(kron(a, b), b)
485+
xp_assert_equal(kron(b, a), b)
486+
xp_assert_equal(kron(1, 1, xp=xp), xp.asarray(1))
487+
488+
def test_all_python_scalars(self):
489+
with pytest.raises(TypeError, match="Unrecognized"):
490+
kron(1, 1)
491+
470492
def test_device(self, xp: ModuleType, device: Device):
471493
x1 = xp.asarray([1, 2, 3], device=device)
472494
x2 = xp.asarray([4, 5], device=device)
@@ -594,6 +616,28 @@ def test_shapes(
594616
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
595617
xp_assert_equal(actual, xp.empty((0,)))
596618

619+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
620+
@pytest.mark.parametrize("assume_unique", [True, False])
621+
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
622+
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
623+
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
624+
x2 = 3
625+
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
626+
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
627+
628+
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
629+
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
630+
631+
xp_assert_equal(
632+
setdiff1d(0, 0, assume_unique=assume_unique, xp=xp),
633+
xp.asarray([0])[:0], # Default int dtype for backend
634+
)
635+
636+
@pytest.mark.parametrize("assume_unique", [True, False])
637+
def test_all_python_scalars(self, assume_unique: bool):
638+
with pytest.raises(TypeError, match="Unrecognized"):
639+
setdiff1d(0, 0, assume_unique=assume_unique)
640+
597641
def test_device(self, xp: ModuleType, device: Device):
598642
x1 = xp.asarray([3, 8, 20], device=device)
599643
x2 = xp.asarray([2, 3, 4], device=device)

tests/test_utils.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from types import ModuleType
22

3+
import numpy as np
34
import pytest
45

56
from array_api_extra._lib import Backend
67
from array_api_extra._lib._testing import xp_assert_equal
78
from array_api_extra._lib._utils._compat import device as get_device
8-
from array_api_extra._lib._utils._helpers import in1d
9+
from array_api_extra._lib._utils._helpers import asarrays, in1d
910
from array_api_extra._lib._utils._typing import Device
1011
from array_api_extra.testing import lazy_xp_function
1112

@@ -45,3 +46,92 @@ def test_xp(self, xp: ModuleType):
4546
expected = xp.asarray([True, False])
4647
actual = in1d(x1, x2, xp=xp)
4748
xp_assert_equal(actual, expected)
49+
50+
51+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
52+
@pytest.mark.parametrize(
53+
("dtype", "b", "defined"),
54+
[
55+
# Well-defined cases of dtype promotion from Python scalar to Array
56+
# bool vs. bool
57+
("bool", True, True),
58+
# int vs. xp.*int*, xp.float*, xp.complex*
59+
("int16", 1, True),
60+
("uint8", 1, True),
61+
("float32", 1, True),
62+
("complex64", 1, True),
63+
# float vs. xp.float, xp.complex
64+
("float32", 1.0, True),
65+
("complex64", 1.0, True),
66+
# complex vs. xp.complex
67+
("complex64", 1.0j, True),
68+
# Undefined cases
69+
("bool", 1, False),
70+
("int64", 1.0, False),
71+
("float64", 1.0j, False),
72+
],
73+
)
74+
def test_asarrays_array_vs_scalar(
75+
dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
76+
):
77+
a = xp.asarray(1, dtype=getattr(xp, dtype))
78+
79+
xa, xb = asarrays(a, b, xp)
80+
assert xa.dtype == a.dtype
81+
if defined:
82+
assert xb.dtype == a.dtype
83+
else:
84+
assert xb.dtype == xp.asarray(b).dtype
85+
86+
xbr, xar = asarrays(b, a, xp)
87+
assert xar.dtype == xa.dtype
88+
assert xbr.dtype == xb.dtype
89+
90+
91+
def test_asarrays_scalar_vs_scalar(xp: ModuleType):
92+
a, b = asarrays(1, 2.2, xp=xp)
93+
assert a.dtype == xp.asarray(1).dtype # Default dtype
94+
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted
95+
96+
97+
ALL_TYPES = (
98+
"int8",
99+
"int16",
100+
"int32",
101+
"int64",
102+
"uint8",
103+
"uint16",
104+
"uint32",
105+
"uint64",
106+
"float32",
107+
"float64",
108+
"complex64",
109+
"complex128",
110+
"bool",
111+
)
112+
113+
114+
@pytest.mark.parametrize("a_type", ALL_TYPES)
115+
@pytest.mark.parametrize("b_type", ALL_TYPES)
116+
def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType):
117+
"""
118+
Test that when both inputs of asarray are already Array API objects,
119+
they are returned unchanged.
120+
"""
121+
a = xp.asarray(1, dtype=getattr(xp, a_type))
122+
b = xp.asarray(1, dtype=getattr(xp, b_type))
123+
xa, xb = asarrays(a, b, xp)
124+
assert xa.dtype == a.dtype
125+
assert xb.dtype == b.dtype
126+
127+
128+
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
129+
def test_asarrays_numpy_generics(dtype: type):
130+
"""
131+
Test special case of np.float64 and np.complex128,
132+
which are subclasses of float and complex.
133+
"""
134+
a = dtype(0)
135+
xa, xb = asarrays(a, 0, xp=np)
136+
assert xa.dtype == dtype
137+
assert xb.dtype == dtype

0 commit comments

Comments
 (0)