Skip to content

Commit 7a9975f

Browse files
authored
BUG: DataFrame.shift with axis=1 and mismatched fill_value (#44564)
1 parent fd95026 commit 7a9975f

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ Other
753753
- Bug in :meth:`Series.to_frame` and :meth:`Index.to_frame` ignoring the ``name`` argument when ``name=None`` is explicitly passed (:issue:`44212`)
754754
- Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` with ``value=None`` and ExtensionDtypes (:issue:`44270`)
755755
- Bug in :meth:`FloatingArray.equals` failing to consider two arrays equal if they contain ``np.nan`` values (:issue:`44382`)
756+
- Bug in :meth:`DataFrame.shift` with ``axis=1`` and ``ExtensionDtype`` columns incorrectly raising when an incompatible ``fill_value`` is passed (:issue:`44564`)
756757
- Bug in :meth:`DataFrame.diff` when passing a NumPy integer object instead of an ``int`` object (:issue:`44572`)
757758
-
758759

pandas/core/internals/managers.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_1d_only_ea_dtype,
3737
is_dtype_equal,
3838
is_list_like,
39+
needs_i8_conversion,
3940
)
4041
from pandas.core.dtypes.dtypes import ExtensionDtype
4142
from pandas.core.dtypes.generic import (
@@ -362,7 +363,28 @@ def shift(self: T, periods: int, axis: int, fill_value) -> T:
362363
if fill_value is lib.no_default:
363364
fill_value = None
364365

365-
if axis == 0 and self.ndim == 2 and self.nblocks > 1:
366+
if (
367+
axis == 0
368+
and self.ndim == 2
369+
and (
370+
self.nblocks > 1
371+
or (
372+
# If we only have one block and we know that we can't
373+
# keep the same dtype (i.e. the _can_hold_element check fails)
374+
# then we can go through the reindex_indexer path
375+
# (and avoid casting logic in the Block method).
376+
# The exception to this (until 2.0) is datetimelike
377+
# dtypes with an integer fill_value, which still cast.
378+
not self.blocks[0]._can_hold_element(fill_value)
379+
# TODO(2.0): remove special case for integer-with-datetimelike
380+
# once deprecation is enforced
381+
and not (
382+
lib.is_integer(fill_value)
383+
and needs_i8_conversion(self.blocks[0].dtype)
384+
)
385+
)
386+
)
387+
):
366388
# GH#35488 we need to watch out for multi-block cases
367389
# We only get here when fill_value is not lib.no_default
368390
ncols = self.shape[0]

pandas/tests/frame/methods/test_shift.py

+77
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,83 @@ def test_shift_dt64values_int_fill_deprecated(self):
331331
expected = DataFrame({"A": [pd.Timestamp(0), pd.Timestamp(0)], "B": df2["A"]})
332332
tm.assert_frame_equal(result, expected)
333333

334+
# same thing but not consolidated
335+
# It isn't great that we get different behavior here, but
336+
# that will go away when the deprecation is enforced
337+
df3 = DataFrame({"A": ser})
338+
df3["B"] = ser
339+
assert len(df3._mgr.arrays) == 2
340+
result = df3.shift(1, axis=1, fill_value=0)
341+
expected = DataFrame({"A": [0, 0], "B": df2["A"]})
342+
tm.assert_frame_equal(result, expected)
343+
344+
@pytest.mark.parametrize(
345+
"as_cat",
346+
[
347+
pytest.param(
348+
True,
349+
marks=pytest.mark.xfail(
350+
reason="_can_hold_element incorrectly always returns True"
351+
),
352+
),
353+
False,
354+
],
355+
)
356+
@pytest.mark.parametrize(
357+
"vals",
358+
[
359+
date_range("2020-01-01", periods=2),
360+
date_range("2020-01-01", periods=2, tz="US/Pacific"),
361+
pd.period_range("2020-01-01", periods=2, freq="D"),
362+
pd.timedelta_range("2020 Days", periods=2, freq="D"),
363+
pd.interval_range(0, 3, periods=2),
364+
pytest.param(
365+
pd.array([1, 2], dtype="Int64"),
366+
marks=pytest.mark.xfail(
367+
reason="_can_hold_element incorrectly always returns True"
368+
),
369+
),
370+
pytest.param(
371+
pd.array([1, 2], dtype="Float32"),
372+
marks=pytest.mark.xfail(
373+
reason="_can_hold_element incorrectly always returns True"
374+
),
375+
),
376+
],
377+
ids=lambda x: str(x.dtype),
378+
)
379+
def test_shift_dt64values_axis1_invalid_fill(
380+
self, vals, as_cat, using_array_manager, request
381+
):
382+
# GH#44564
383+
if using_array_manager:
384+
mark = pytest.mark.xfail(raises=NotImplementedError)
385+
request.node.add_marker(mark)
386+
387+
ser = Series(vals)
388+
if as_cat:
389+
ser = ser.astype("category")
390+
391+
df = DataFrame({"A": ser})
392+
result = df.shift(-1, axis=1, fill_value="foo")
393+
expected = DataFrame({"A": ["foo", "foo"]})
394+
tm.assert_frame_equal(result, expected)
395+
396+
# same thing but multiple blocks
397+
df2 = DataFrame({"A": ser, "B": ser})
398+
df2._consolidate_inplace()
399+
400+
result = df2.shift(-1, axis=1, fill_value="foo")
401+
expected = DataFrame({"A": df2["B"], "B": ["foo", "foo"]})
402+
tm.assert_frame_equal(result, expected)
403+
404+
# same thing but not consolidated
405+
df3 = DataFrame({"A": ser})
406+
df3["B"] = ser
407+
assert len(df3._mgr.arrays) == 2
408+
result = df3.shift(-1, axis=1, fill_value="foo")
409+
tm.assert_frame_equal(result, expected)
410+
334411
def test_shift_axis1_categorical_columns(self):
335412
# GH#38434
336413
ci = CategoricalIndex(["a", "b", "c"])

0 commit comments

Comments
 (0)