Skip to content

Commit bae7482

Browse files
committed
Add multi-device support to stats and sets
1 parent ff37de7 commit bae7482

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

array_api_strict/_set_functions.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ def unique_all(x: Array, /) -> UniqueAllResult:
5555
# See https://github.com/numpy/numpy/issues/20638
5656
inverse_indices = inverse_indices.reshape(x.shape)
5757
return UniqueAllResult(
58-
Array._new(values),
59-
Array._new(indices),
60-
Array._new(inverse_indices),
61-
Array._new(counts),
58+
Array._new(values, device=x.device),
59+
Array._new(indices, device=x.device),
60+
Array._new(inverse_indices, device=x.device),
61+
Array._new(counts, device=x.device),
6262
)
6363

6464

@@ -72,7 +72,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
7272
equal_nan=False,
7373
)
7474

75-
return UniqueCountsResult(*[Array._new(i) for i in res])
75+
return UniqueCountsResult(*[Array._new(i, device=x.device) for i in res])
7676

7777

7878
@requires_data_dependent_shapes
@@ -92,7 +92,8 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
9292
# np.unique() flattens inverse indices, but they need to share x's shape
9393
# See https://github.com/numpy/numpy/issues/20638
9494
inverse_indices = inverse_indices.reshape(x.shape)
95-
return UniqueInverseResult(Array._new(values), Array._new(inverse_indices))
95+
return UniqueInverseResult(Array._new(values, device=x.device),
96+
Array._new(inverse_indices, device=x.device))
9697

9798

9899
@requires_data_dependent_shapes
@@ -109,4 +110,4 @@ def unique_values(x: Array, /) -> Array:
109110
return_inverse=False,
110111
equal_nan=False,
111112
)
112-
return Array._new(res)
113+
return Array._new(res, device=x.device)

array_api_strict/_statistical_functions.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def cumulative_sum(
4444
if axis < 0:
4545
axis += x.ndim
4646
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
47-
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype))
47+
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
4848

4949
def max(
5050
x: Array,
@@ -55,7 +55,7 @@ def max(
5555
) -> Array:
5656
if x.dtype not in _real_numeric_dtypes:
5757
raise TypeError("Only real numeric dtypes are allowed in max")
58-
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
58+
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims), device=x.device)
5959

6060

6161
def mean(
@@ -67,7 +67,7 @@ def mean(
6767
) -> Array:
6868
if x.dtype not in _real_floating_dtypes:
6969
raise TypeError("Only real floating-point dtypes are allowed in mean")
70-
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
70+
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device)
7171

7272

7373
def min(
@@ -79,7 +79,7 @@ def min(
7979
) -> Array:
8080
if x.dtype not in _real_numeric_dtypes:
8181
raise TypeError("Only real numeric dtypes are allowed in min")
82-
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
82+
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device)
8383

8484

8585
def prod(
@@ -104,7 +104,7 @@ def prod(
104104
dtype = np.complex128
105105
else:
106106
dtype = dtype._np_dtype
107-
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
107+
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device)
108108

109109

110110
def std(
@@ -118,7 +118,7 @@ def std(
118118
# Note: the keyword argument correction is different here
119119
if x.dtype not in _real_floating_dtypes:
120120
raise TypeError("Only real floating-point dtypes are allowed in std")
121-
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
121+
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device)
122122

123123

124124
def sum(
@@ -143,7 +143,7 @@ def sum(
143143
dtype = np.complex128
144144
else:
145145
dtype = dtype._np_dtype
146-
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
146+
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device)
147147

148148

149149
def var(
@@ -157,4 +157,4 @@ def var(
157157
# Note: the keyword argument correction is different here
158158
if x.dtype not in _real_floating_dtypes:
159159
raise TypeError("Only real floating-point dtypes are allowed in var")
160-
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))
160+
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device)

0 commit comments

Comments
 (0)