Skip to content

Commit a043bb8

Browse files
authored
Merge pull request #20533 from asmeurer/array_api-prod-fix
BUG: Fix handling of the dtype parameter to numpy.array_api.prod() Original NumPy Commit: ab7a1927353ab9dd52e3f2f7a1a889ae790667b9
2 parents a52063e + 9410ad3 commit a043bb8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

array_api_strict/_statistical_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def prod(
6565
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
6666
# We need to do so here before computing the product to avoid overflow
6767
if dtype is None and x.dtype == float32:
68-
x = asarray(x, dtype=float64)
69-
return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
68+
dtype = float64
69+
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
7070

7171

7272
def std(

0 commit comments

Comments
 (0)