Skip to content

Commit aa390ec

Browse files
authored
REF: collect boilerplate in _datetimelike_compat (#37723)
1 parent 80749af commit aa390ec

File tree

1 file changed

+30
-23
lines changed

1 file changed

+30
-23
lines changed

pandas/core/nanops.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,32 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None):
367367
return result
368368

369369

370+
def _datetimelike_compat(func):
371+
"""
372+
If we have datetime64 or timedelta64 values, ensure we have a correct
373+
mask before calling the wrapped function, then cast back afterwards.
374+
"""
375+
376+
@functools.wraps(func)
377+
def new_func(values, *, axis=None, skipna=True, mask=None, **kwargs):
378+
orig_values = values
379+
380+
datetimelike = values.dtype.kind in ["m", "M"]
381+
if datetimelike and mask is None:
382+
mask = isna(values)
383+
384+
result = func(values, axis=axis, skipna=skipna, mask=mask, **kwargs)
385+
386+
if datetimelike:
387+
result = _wrap_results(result, orig_values.dtype, fill_value=iNaT)
388+
if not skipna:
389+
result = _mask_datetimelike_result(result, axis, mask, orig_values)
390+
391+
return result
392+
393+
return new_func
394+
395+
370396
def _na_for_min_count(
371397
values: np.ndarray, axis: Optional[int]
372398
) -> Union[Scalar, np.ndarray]:
@@ -480,6 +506,7 @@ def nanall(
480506

481507

482508
@disallow("M8")
509+
@_datetimelike_compat
483510
def nansum(
484511
values: np.ndarray,
485512
*,
@@ -511,25 +538,18 @@ def nansum(
511538
>>> nanops.nansum(s)
512539
3.0
513540
"""
514-
orig_values = values
515-
516541
values, mask, dtype, dtype_max, _ = _get_values(
517542
values, skipna, fill_value=0, mask=mask
518543
)
519544
dtype_sum = dtype_max
520-
datetimelike = False
521545
if is_float_dtype(dtype):
522546
dtype_sum = dtype
523547
elif is_timedelta64_dtype(dtype):
524-
datetimelike = True
525548
dtype_sum = np.float64
526549

527550
the_sum = values.sum(axis, dtype=dtype_sum)
528551
the_sum = _maybe_null_out(the_sum, axis, mask, values.shape, min_count=min_count)
529552

530-
the_sum = _wrap_results(the_sum, dtype)
531-
if datetimelike and not skipna:
532-
the_sum = _mask_datetimelike_result(the_sum, axis, mask, orig_values)
533553
return the_sum
534554

535555

@@ -552,6 +572,7 @@ def _mask_datetimelike_result(
552572

553573
@disallow(PeriodDtype)
554574
@bottleneck_switch()
575+
@_datetimelike_compat
555576
def nanmean(
556577
values: np.ndarray,
557578
*,
@@ -583,18 +604,14 @@ def nanmean(
583604
>>> nanops.nanmean(s)
584605
1.5
585606
"""
586-
orig_values = values
587-
588607
values, mask, dtype, dtype_max, _ = _get_values(
589608
values, skipna, fill_value=0, mask=mask
590609
)
591610
dtype_sum = dtype_max
592611
dtype_count = np.float64
593612

594613
# not using needs_i8_conversion because that includes period
595-
datetimelike = False
596614
if dtype.kind in ["m", "M"]:
597-
datetimelike = True
598615
dtype_sum = np.float64
599616
elif is_integer_dtype(dtype):
600617
dtype_sum = np.float64
@@ -616,9 +633,6 @@ def nanmean(
616633
else:
617634
the_mean = the_sum / count if count > 0 else np.nan
618635

619-
the_mean = _wrap_results(the_mean, dtype)
620-
if datetimelike and not skipna:
621-
the_mean = _mask_datetimelike_result(the_mean, axis, mask, orig_values)
622636
return the_mean
623637

624638

@@ -875,7 +889,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None):
875889
# precision as the original values array.
876890
if is_float_dtype(dtype):
877891
result = result.astype(dtype)
878-
return _wrap_results(result, values.dtype)
892+
return result
879893

880894

881895
@disallow("M8", "m8")
@@ -930,6 +944,7 @@ def nansem(
930944

931945
def _nanminmax(meth, fill_value_typ):
932946
@bottleneck_switch(name="nan" + meth)
947+
@_datetimelike_compat
933948
def reduction(
934949
values: np.ndarray,
935950
*,
@@ -938,13 +953,10 @@ def reduction(
938953
mask: Optional[np.ndarray] = None,
939954
) -> Dtype:
940955

941-
orig_values = values
942956
values, mask, dtype, dtype_max, fill_value = _get_values(
943957
values, skipna, fill_value_typ=fill_value_typ, mask=mask
944958
)
945959

946-
datetimelike = orig_values.dtype.kind in ["m", "M"]
947-
948960
if (axis is not None and values.shape[axis] == 0) or values.size == 0:
949961
try:
950962
result = getattr(values, meth)(axis, dtype=dtype_max)
@@ -954,12 +966,7 @@ def reduction(
954966
else:
955967
result = getattr(values, meth)(axis)
956968

957-
result = _wrap_results(result, dtype, fill_value)
958969
result = _maybe_null_out(result, axis, mask, values.shape)
959-
960-
if datetimelike and not skipna:
961-
result = _mask_datetimelike_result(result, axis, mask, orig_values)
962-
963970
return result
964971

965972
return reduction

0 commit comments

Comments
 (0)