Skip to content

Commit 6ee20b0

Browse files
authored
REF: implement _wrap_reduction_result (#37660)
1 parent: a347bc1 · commit: 6ee20b0

File tree

7 files changed, with 61 additions and 36 deletions.

pandas/core/arrays/_mixins.py

+6 -1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Sequence, TypeVar
1+
from typing import Any, Optional, Sequence, TypeVar
22

33
import numpy as np
44

@@ -255,6 +255,11 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
255255
msg = f"'{type(self).__name__}' does not implement reduction '{name}'"
256256
raise TypeError(msg)
257257

258+
def _wrap_reduction_result(self, axis: Optional[int], result):
259+
if axis is None or self.ndim == 1:
260+
return self._box_func(result)
261+
return self._from_backing_data(result)
262+
258263
# ------------------------------------------------------------------------
259264

260265
def __repr__(self) -> str:

pandas/core/arrays/categorical.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -1957,7 +1957,7 @@ def min(self, *, skipna=True, **kwargs):
19571957
return np.nan
19581958
else:
19591959
pointer = self._codes.min()
1960-
return self.categories[pointer]
1960+
return self._wrap_reduction_result(None, pointer)
19611961

19621962
@deprecate_kwarg(old_arg_name="numeric_only", new_arg_name="skipna")
19631963
def max(self, *, skipna=True, **kwargs):
@@ -1993,7 +1993,7 @@ def max(self, *, skipna=True, **kwargs):
19931993
return np.nan
19941994
else:
19951995
pointer = self._codes.max()
1996-
return self.categories[pointer]
1996+
return self._wrap_reduction_result(None, pointer)
19971997

19981998
def mode(self, dropna=True):
19991999
"""

pandas/core/arrays/datetimelike.py

+4 -12
Original file line number | Diff line number | Diff line change
@@ -1283,9 +1283,7 @@ def min(self, *, axis=None, skipna=True, **kwargs):
12831283
return self._from_backing_data(result)
12841284

12851285
result = nanops.nanmin(self._ndarray, axis=axis, skipna=skipna)
1286-
if lib.is_scalar(result):
1287-
return self._box_func(result)
1288-
return self._from_backing_data(result)
1286+
return self._wrap_reduction_result(axis, result)
12891287

12901288
def max(self, *, axis=None, skipna=True, **kwargs):
12911289
"""
@@ -1316,9 +1314,7 @@ def max(self, *, axis=None, skipna=True, **kwargs):
13161314
return self._from_backing_data(result)
13171315

13181316
result = nanops.nanmax(self._ndarray, axis=axis, skipna=skipna)
1319-
if lib.is_scalar(result):
1320-
return self._box_func(result)
1321-
return self._from_backing_data(result)
1317+
return self._wrap_reduction_result(axis, result)
13221318

13231319
def mean(self, *, skipna=True, axis: Optional[int] = 0):
13241320
"""
@@ -1357,9 +1353,7 @@ def mean(self, *, skipna=True, axis: Optional[int] = 0):
13571353
result = nanops.nanmean(
13581354
self._ndarray, axis=axis, skipna=skipna, mask=self.isna()
13591355
)
1360-
if axis is None or self.ndim == 1:
1361-
return self._box_func(result)
1362-
return self._from_backing_data(result)
1356+
return self._wrap_reduction_result(axis, result)
13631357

13641358
def median(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
13651359
nv.validate_median((), kwargs)
@@ -1378,9 +1372,7 @@ def median(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
13781372
return self._from_backing_data(result)
13791373

13801374
result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
1381-
if axis is None or self.ndim == 1:
1382-
return self._box_func(result)
1383-
return self._from_backing_data(result)
1375+
return self._wrap_reduction_result(axis, result)
13841376

13851377

13861378
class DatelikeOps(DatetimeLikeArrayMixin):

pandas/core/arrays/numpy_.py

+30 -18
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
from pandas.core.dtypes.missing import isna
1313

1414
from pandas.core import nanops, ops
15-
from pandas.core.array_algos import masked_reductions
1615
from pandas.core.arraylike import OpsMixin
1716
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
1817
from pandas.core.strings.object_array import ObjectStringArrayMixin
@@ -275,83 +274,96 @@ def _values_for_factorize(self) -> Tuple[np.ndarray, int]:
275274

276275
def any(self, *, axis=None, out=None, keepdims=False, skipna=True):
277276
nv.validate_any((), dict(out=out, keepdims=keepdims))
278-
return nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
277+
result = nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
278+
return self._wrap_reduction_result(axis, result)
279279

280280
def all(self, *, axis=None, out=None, keepdims=False, skipna=True):
281281
nv.validate_all((), dict(out=out, keepdims=keepdims))
282-
return nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
282+
result = nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
283+
return self._wrap_reduction_result(axis, result)
283284

284-
def min(self, *, skipna: bool = True, **kwargs) -> Scalar:
285+
def min(self, *, axis=None, skipna: bool = True, **kwargs) -> Scalar:
285286
nv.validate_min((), kwargs)
286-
return masked_reductions.min(
287-
values=self.to_numpy(), mask=self.isna(), skipna=skipna
287+
result = nanops.nanmin(
288+
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
288289
)
290+
return self._wrap_reduction_result(axis, result)
289291

290-
def max(self, *, skipna: bool = True, **kwargs) -> Scalar:
292+
def max(self, *, axis=None, skipna: bool = True, **kwargs) -> Scalar:
291293
nv.validate_max((), kwargs)
292-
return masked_reductions.max(
293-
values=self.to_numpy(), mask=self.isna(), skipna=skipna
294+
result = nanops.nanmax(
295+
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
294296
)
297+
return self._wrap_reduction_result(axis, result)
295298

296299
def sum(self, *, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar:
297300
nv.validate_sum((), kwargs)
298-
return nanops.nansum(
301+
result = nanops.nansum(
299302
self._ndarray, axis=axis, skipna=skipna, min_count=min_count
300303
)
304+
return self._wrap_reduction_result(axis, result)
301305

302306
def prod(self, *, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar:
303307
nv.validate_prod((), kwargs)
304-
return nanops.nanprod(
308+
result = nanops.nanprod(
305309
self._ndarray, axis=axis, skipna=skipna, min_count=min_count
306310
)
311+
return self._wrap_reduction_result(axis, result)
307312

308313
def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
309314
nv.validate_mean((), dict(dtype=dtype, out=out, keepdims=keepdims))
310-
return nanops.nanmean(self._ndarray, axis=axis, skipna=skipna)
315+
result = nanops.nanmean(self._ndarray, axis=axis, skipna=skipna)
316+
return self._wrap_reduction_result(axis, result)
311317

312318
def median(
313319
self, *, axis=None, out=None, overwrite_input=False, keepdims=False, skipna=True
314320
):
315321
nv.validate_median(
316322
(), dict(out=out, overwrite_input=overwrite_input, keepdims=keepdims)
317323
)
318-
return nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
324+
result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
325+
return self._wrap_reduction_result(axis, result)
319326

320327
def std(
321328
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
322329
):
323330
nv.validate_stat_ddof_func(
324331
(), dict(dtype=dtype, out=out, keepdims=keepdims), fname="std"
325332
)
326-
return nanops.nanstd(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
333+
result = nanops.nanstd(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
334+
return self._wrap_reduction_result(axis, result)
327335

328336
def var(
329337
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
330338
):
331339
nv.validate_stat_ddof_func(
332340
(), dict(dtype=dtype, out=out, keepdims=keepdims), fname="var"
333341
)
334-
return nanops.nanvar(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
342+
result = nanops.nanvar(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
343+
return self._wrap_reduction_result(axis, result)
335344

336345
def sem(
337346
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
338347
):
339348
nv.validate_stat_ddof_func(
340349
(), dict(dtype=dtype, out=out, keepdims=keepdims), fname="sem"
341350
)
342-
return nanops.nansem(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
351+
result = nanops.nansem(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
352+
return self._wrap_reduction_result(axis, result)
343353

344354
def kurt(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
345355
nv.validate_stat_ddof_func(
346356
(), dict(dtype=dtype, out=out, keepdims=keepdims), fname="kurt"
347357
)
348-
return nanops.nankurt(self._ndarray, axis=axis, skipna=skipna)
358+
result = nanops.nankurt(self._ndarray, axis=axis, skipna=skipna)
359+
return self._wrap_reduction_result(axis, result)
349360

350361
def skew(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
351362
nv.validate_stat_ddof_func(
352363
(), dict(dtype=dtype, out=out, keepdims=keepdims), fname="skew"
353364
)
354-
return nanops.nanskew(self._ndarray, axis=axis, skipna=skipna)
365+
result = nanops.nanskew(self._ndarray, axis=axis, skipna=skipna)
366+
return self._wrap_reduction_result(axis, result)
355367

356368
# ------------------------------------------------------------------------
357369
# Additional Methods

pandas/core/arrays/string_.py

+17 -0
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44

55
from pandas._libs import lib, missing as libmissing
6+
from pandas._typing import Scalar
7+
from pandas.compat.numpy import function as nv
68

79
from pandas.core.dtypes.base import ExtensionDtype, register_extension_dtype
810
from pandas.core.dtypes.common import (
@@ -15,6 +17,7 @@
1517
)
1618

1719
from pandas.core import ops
20+
from pandas.core.array_algos import masked_reductions
1821
from pandas.core.arrays import IntegerArray, PandasArray
1922
from pandas.core.arrays.integer import _IntegerDtype
2023
from pandas.core.construction import extract_array
@@ -301,6 +304,20 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
301304

302305
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
303306

307+
def min(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
308+
nv.validate_min((), kwargs)
309+
result = masked_reductions.min(
310+
values=self.to_numpy(), mask=self.isna(), skipna=skipna
311+
)
312+
return self._wrap_reduction_result(axis, result)
313+
314+
def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
315+
nv.validate_max((), kwargs)
316+
result = masked_reductions.max(
317+
values=self.to_numpy(), mask=self.isna(), skipna=skipna
318+
)
319+
return self._wrap_reduction_result(axis, result)
320+
304321
def value_counts(self, dropna=False):
305322
from pandas import value_counts
306323

pandas/core/arrays/timedeltas.py

+1 -3
Original file line number | Diff line number | Diff line change
@@ -381,9 +381,7 @@ def sum(
381381
result = nanops.nansum(
382382
self._ndarray, axis=axis, skipna=skipna, min_count=min_count
383383
)
384-
if axis is None or self.ndim == 1:
385-
return self._box_func(result)
386-
return self._from_backing_data(result)
384+
return self._wrap_reduction_result(axis, result)
387385

388386
def std(
389387
self,

pandas/core/nanops.py

+1 -0
Original file line number | Diff line number | Diff line change
@@ -344,6 +344,7 @@ def _wrap_results(result, dtype: DtypeObj, fill_value=None):
344344
assert not isna(fill_value), "Expected non-null fill_value"
345345
if result == fill_value:
346346
result = np.nan
347+
347348
if tz is not None:
348349
# we get here e.g. via nanmean when we call it on a DTA[tz]
349350
result = Timestamp(result, tz=tz)

0 commit comments