Skip to content

Commit 9410ad3

Browse files
committed
BUG: Fix handling of the dtype parameter to numpy.array_api.prod()
Original NumPy Commit: 19a398ae78c3f35ce3d29d87b35da059558d72b4
1 parent 472e03f commit 9410ad3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

array_api_strict/_statistical_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def prod(
6565
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
6666
# We need to do so here before computing the product to avoid overflow
6767
if dtype is None and x.dtype == float32:
68-
x = asarray(x, dtype=float64)
69-
return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
68+
dtype = float64
69+
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
7070

7171

7272
def std(

0 commit comments

Comments
 (0)