Skip to content

Commit cff293b

Browse files
authored
REF/BUG: cast back to datetimelike in pad/backfill (#40052)
1 parent 6fcc188 commit cff293b

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

pandas/core/missing.py

+34-32
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
"""
44
from __future__ import annotations
55

6-
from functools import partial
6+
from functools import (
7+
partial,
8+
wraps,
9+
)
710
from typing import (
811
TYPE_CHECKING,
912
Any,
1013
List,
1114
Optional,
1215
Set,
1316
Union,
17+
cast,
1418
)
1519

1620
import numpy as np
@@ -22,15 +26,13 @@
2226
from pandas._typing import (
2327
ArrayLike,
2428
Axis,
25-
DtypeObj,
29+
F,
2630
)
2731
from pandas.compat._optional import import_optional_dependency
2832

2933
from pandas.core.dtypes.cast import infer_dtype_from
3034
from pandas.core.dtypes.common import (
31-
ensure_float64,
3235
is_array_like,
33-
is_integer_dtype,
3436
is_numeric_v_string_like,
3537
needs_i8_conversion,
3638
)
@@ -674,54 +676,53 @@ def interpolate_2d(
674676
return result
675677

676678

677-
def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool):
678-
"""
679-
Cast values to a dtype that algos.pad and algos.backfill can handle.
680-
"""
681-
# TODO: for int-dtypes we make a copy, but for everything else this
682-
# alters the values in-place. Is this intentional?
679+
def _fillna_prep(values, mask=None):
680+
# boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d
683681

684-
if needs_i8_conversion(dtype):
685-
values = values.view(np.int64)
682+
if mask is None:
683+
mask = isna(values)
686684

687-
elif is_integer_dtype(values) and not has_mask:
688-
# NB: this check needs to come after the datetime64 check above
689-
# has_mask check to avoid casting i8 values that have already
690-
# been cast from PeriodDtype
691-
values = ensure_float64(values)
685+
mask = mask.view(np.uint8)
686+
return mask
692687

693-
return values
694688

689+
def _datetimelike_compat(func: F) -> F:
690+
"""
691+
Wrapper to handle datetime64 and timedelta64 dtypes.
692+
"""
695693

696-
def _fillna_prep(values, mask=None):
697-
# boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d
698-
dtype = values.dtype
694+
@wraps(func)
695+
def new_func(values, limit=None, mask=None):
696+
if needs_i8_conversion(values.dtype):
697+
if mask is None:
698+
# This needs to occur before casting to int64
699+
mask = isna(values)
699700

700-
has_mask = mask is not None
701-
if not has_mask:
702-
# This needs to occur before datetime/timedeltas are cast to int64
703-
mask = isna(values)
701+
result = func(values.view("i8"), limit=limit, mask=mask)
702+
return result.view(values.dtype)
704703

705-
values = _cast_values_for_fillna(values, dtype, has_mask)
704+
return func(values, limit=limit, mask=mask)
706705

707-
mask = mask.view(np.uint8)
708-
return values, mask
706+
return cast(F, new_func)
709707

710708

709+
@_datetimelike_compat
711710
def _pad_1d(values, limit=None, mask=None):
712-
values, mask = _fillna_prep(values, mask)
711+
mask = _fillna_prep(values, mask)
713712
algos.pad_inplace(values, mask, limit=limit)
714713
return values
715714

716715

716+
@_datetimelike_compat
717717
def _backfill_1d(values, limit=None, mask=None):
718-
values, mask = _fillna_prep(values, mask)
718+
mask = _fillna_prep(values, mask)
719719
algos.backfill_inplace(values, mask, limit=limit)
720720
return values
721721

722722

723+
@_datetimelike_compat
723724
def _pad_2d(values, limit=None, mask=None):
724-
values, mask = _fillna_prep(values, mask)
725+
mask = _fillna_prep(values, mask)
725726

726727
if np.all(values.shape):
727728
algos.pad_2d_inplace(values, mask, limit=limit)
@@ -731,8 +732,9 @@ def _pad_2d(values, limit=None, mask=None):
731732
return values
732733

733734

735+
@_datetimelike_compat
734736
def _backfill_2d(values, limit=None, mask=None):
735-
values, mask = _fillna_prep(values, mask)
737+
mask = _fillna_prep(values, mask)
736738

737739
if np.all(values.shape):
738740
algos.backfill_2d_inplace(values, mask, limit=limit)

0 commit comments

Comments
 (0)