Skip to content

Commit e4b6bfe

Browse files
authored
Merge pull request #110 from ev-br/cmplx_mean
ENH: allow mean(complex) in 2024.12
2 parents def141d + 4ac9255 commit e4b6bfe

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

array_api_strict/_statistical_functions.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._dtypes import (
44
_real_floating_dtypes,
55
_real_numeric_dtypes,
6+
_floating_dtypes,
67
_numeric_dtypes,
78
)
89
from ._array_object import Array
@@ -65,8 +66,14 @@ def mean(
6566
axis: Optional[Union[int, Tuple[int, ...]]] = None,
6667
keepdims: bool = False,
6768
) -> Array:
68-
if x.dtype not in _real_floating_dtypes:
69-
raise TypeError("Only real floating-point dtypes are allowed in mean")
69+
70+
if get_array_api_strict_flags()['api_version'] > '2023.12':
71+
allowed_dtypes = _floating_dtypes
72+
else:
73+
allowed_dtypes = _real_floating_dtypes
74+
75+
if x.dtype not in allowed_dtypes:
76+
raise TypeError("Only floating-point dtypes are allowed in mean")
7077
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device)
7178

7279

array_api_strict/tests/test_statistical_functions.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cmath
12
import pytest
23

34
from .._flags import set_array_api_strict_flags
@@ -37,3 +38,22 @@ def test_sum_prod_trace_2023_12(func_name):
3738
assert func(a_real).dtype == xp.float32
3839
assert func(a_complex).dtype == xp.complex64
3940
assert func(a_int).dtype == xp.int64
41+
42+
43+
# mean(complex-valued array) is allowed from 2024.12 onwards
44+
def test_mean_complex():
45+
a = xp.asarray([1j, 2j, 3j])
46+
47+
set_array_api_strict_flags(api_version='2023.12')
48+
with pytest.raises(TypeError):
49+
xp.mean(a)
50+
51+
with pytest.warns(UserWarning):
52+
set_array_api_strict_flags(api_version='2024.12')
53+
m = xp.mean(a)
54+
assert cmath.isclose(complex(m), 2j)
55+
56+
# mean of integer arrays is still not allowed
57+
with pytest.raises(TypeError):
58+
xp.mean(xp.arange(3))
59+

0 commit comments

Comments
 (0)