Skip to content

Commit 6658d89

Browse files
authored
ENH: Add prod to masked_reductions (#33442)
1 parent 9c31732 commit 6658d89

File tree

7 files changed

+40
-38
lines changed

7 files changed

+40
-38
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ Performance improvements
378378
sparse values from ``scipy.sparse`` matrices using the
379379
:meth:`DataFrame.sparse.from_spmatrix` constructor (:issue:`32821`,
380380
:issue:`32825`, :issue:`32826`, :issue:`32856`, :issue:`32858`).
381-
- Performance improvement in reductions (sum, min, max) for nullable (integer and boolean) dtypes (:issue:`30982`, :issue:`33261`).
381+
- Performance improvement in reductions (sum, prod, min, max) for nullable (integer and boolean) dtypes (:issue:`30982`, :issue:`33261`, :issue:`33442`).
382382

383383

384384
.. ---------------------------------------------------------------------------

pandas/core/array_algos/masked_reductions.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
for missing values.
44
"""
55

6+
from typing import Callable
7+
68
import numpy as np
79

810
from pandas._libs import missing as libmissing
@@ -11,14 +13,19 @@
1113
from pandas.core.nanops import check_below_min_count
1214

1315

14-
def sum(
15-
values: np.ndarray, mask: np.ndarray, skipna: bool = True, min_count: int = 0,
16+
def _sumprod(
17+
func: Callable,
18+
values: np.ndarray,
19+
mask: np.ndarray,
20+
skipna: bool = True,
21+
min_count: int = 0,
1622
):
1723
"""
18-
Sum for 1D masked array.
24+
Sum or product for 1D masked array.
1925
2026
Parameters
2127
----------
28+
func : np.sum or np.prod
2229
values : np.ndarray
2330
Numpy array with the values (can be of any dtype that support the
2431
operation).
@@ -31,23 +38,33 @@ def sum(
3138
``min_count`` non-NA values are present the result will be NA.
3239
"""
3340
if not skipna:
34-
if mask.any():
41+
if mask.any() or check_below_min_count(values.shape, None, min_count):
3542
return libmissing.NA
3643
else:
37-
if check_below_min_count(values.shape, None, min_count):
38-
return libmissing.NA
39-
return np.sum(values)
44+
return func(values)
4045
else:
4146
if check_below_min_count(values.shape, mask, min_count):
4247
return libmissing.NA
4348

4449
if _np_version_under1p17:
45-
return np.sum(values[~mask])
50+
return func(values[~mask])
4651
else:
47-
return np.sum(values, where=~mask)
52+
return func(values, where=~mask)
53+
54+
55+
def sum(values: np.ndarray, mask: np.ndarray, skipna: bool = True, min_count: int = 0):
56+
return _sumprod(
57+
np.sum, values=values, mask=mask, skipna=skipna, min_count=min_count
58+
)
4859

4960

50-
def _minmax(func, values: np.ndarray, mask: np.ndarray, skipna: bool = True):
61+
def prod(values: np.ndarray, mask: np.ndarray, skipna: bool = True, min_count: int = 0):
62+
return _sumprod(
63+
np.prod, values=values, mask=mask, skipna=skipna, min_count=min_count
64+
)
65+
66+
67+
def _minmax(func: Callable, values: np.ndarray, mask: np.ndarray, skipna: bool = True):
5168
"""
5269
Reduction for 1D masked array.
5370
@@ -63,18 +80,15 @@ def _minmax(func, values: np.ndarray, mask: np.ndarray, skipna: bool = True):
6380
Whether to skip NA.
6481
"""
6582
if not skipna:
66-
if mask.any():
83+
if mask.any() or not values.size:
84+
# min/max with empty array raise in numpy, pandas returns NA
6785
return libmissing.NA
6886
else:
69-
if values.size:
70-
return func(values)
71-
else:
72-
# min/max with empty array raise in numpy, pandas returns NA
73-
return libmissing.NA
87+
return func(values)
7488
else:
7589
subset = values[~mask]
7690
if subset.size:
77-
return func(values[~mask])
91+
return func(subset)
7892
else:
7993
# min/max with empty array raise in numpy, pandas returns NA
8094
return libmissing.NA

pandas/core/arrays/boolean.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from pandas.core.dtypes.dtypes import register_extension_dtype
2626
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
27-
from pandas.core.dtypes.missing import isna, notna
27+
from pandas.core.dtypes.missing import isna
2828

2929
from pandas.core import nanops, ops
3030
from pandas.core.array_algos import masked_reductions
@@ -686,7 +686,7 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
686686
data = self._data
687687
mask = self._mask
688688

689-
if name in {"sum", "min", "max"}:
689+
if name in {"sum", "prod", "min", "max"}:
690690
op = getattr(masked_reductions, name)
691691
return op(data, mask, skipna=skipna, **kwargs)
692692

@@ -700,12 +700,6 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
700700
if np.isnan(result):
701701
return libmissing.NA
702702

703-
# if we have numeric op that would result in an int, coerce to int if possible
704-
if name == "prod" and notna(result):
705-
int_result = np.int64(result)
706-
if int_result == result:
707-
result = int_result
708-
709703
return result
710704

711705
def _maybe_mask_result(self, result, mask, other, op_name: str):

pandas/core/arrays/integer.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
from pandas.core import nanops, ops
3030
from pandas.core.array_algos import masked_reductions
31-
import pandas.core.common as com
3231
from pandas.core.indexers import check_array_indexer
3332
from pandas.core.ops import invalid_comparison
3433
from pandas.core.ops.common import unpack_zerodim_and_defer
@@ -557,7 +556,7 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
557556
data = self._data
558557
mask = self._mask
559558

560-
if name in {"sum", "min", "max"}:
559+
if name in {"sum", "prod", "min", "max"}:
561560
op = getattr(masked_reductions, name)
562561
return op(data, mask, skipna=skipna, **kwargs)
563562

@@ -576,12 +575,6 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
576575
if name in ["any", "all"]:
577576
pass
578577

579-
# if we have a preservable numeric op,
580-
# provide coercion back to an integer type if possible
581-
elif name == "prod":
582-
# GH#31409 more performant than casting-then-checking
583-
result = com.cast_scalar_indexer(result)
584-
585578
return result
586579

587580
def _maybe_mask_result(self, result, mask, other, op_name: str):

pandas/tests/arrays/boolean/test_reduction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_reductions_return_types(dropna, data, all_numeric_reductions):
5252
if op == "sum":
5353
assert isinstance(getattr(s, op)(), np.int_)
5454
elif op == "prod":
55-
assert isinstance(getattr(s, op)(), np.int64)
55+
assert isinstance(getattr(s, op)(), np.int_)
5656
elif op in ("min", "max"):
5757
assert isinstance(getattr(s, op)(), np.bool_)
5858
else:

pandas/tests/arrays/integer/test_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_preserve_dtypes(op):
3434

3535
# op
3636
result = getattr(df.C, op)()
37-
if op in {"sum", "min", "max"}:
37+
if op in {"sum", "prod", "min", "max"}:
3838
assert isinstance(result, np.int64)
3939
else:
4040
assert isinstance(result, int)

pandas/tests/extension/test_integer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,10 @@ def check_reduce(self, s, op_name, skipna):
238238
# overwrite to ensure pd.NA is tested instead of np.nan
239239
# https://github.com/pandas-dev/pandas/issues/30958
240240
result = getattr(s, op_name)(skipna=skipna)
241-
expected = getattr(s.astype("float64"), op_name)(skipna=skipna)
242-
if np.isnan(expected):
241+
if not skipna and s.isna().any():
243242
expected = pd.NA
243+
else:
244+
expected = getattr(s.dropna().astype("int64"), op_name)(skipna=skipna)
244245
tm.assert_almost_equal(result, expected)
245246

246247

0 commit comments

Comments
 (0)