Skip to content

Commit ae67cb7

Browse files
authored
BUG: frame.shift(axis=1) with ArrayManager (#45644)
1 parent f4c167a commit ae67cb7

File tree

5 files changed

+47
-69
lines changed

5 files changed

+47
-69
lines changed

pandas/core/frame.py

+43
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
)
9494

9595
from pandas.core.dtypes.cast import (
96+
can_hold_element,
9697
construct_1d_arraylike_from_scalar,
9798
construct_2d_arraylike_from_scalar,
9899
find_common_type,
@@ -5366,6 +5367,48 @@ def shift(
53665367

53675368
result.columns = self.columns.copy()
53685369
return result
5370+
elif (
5371+
axis == 1
5372+
and periods != 0
5373+
and fill_value is not lib.no_default
5374+
and ncols > 0
5375+
):
5376+
arrays = self._mgr.arrays
5377+
if len(arrays) > 1 or (
5378+
# If we only have one block and we know that we can't
5379+
# keep the same dtype (i.e. the _can_hold_element check)
5380+
# then we can go through the reindex_indexer path
5381+
# (and avoid casting logic in the Block method).
5382+
# The exception to this (until 2.0) is datetimelike
5383+
# dtypes with integers, which cast.
5384+
not can_hold_element(arrays[0], fill_value)
5385+
# TODO(2.0): remove special case for integer-with-datetimelike
5386+
# once deprecation is enforced
5387+
and not (
5388+
lib.is_integer(fill_value) and needs_i8_conversion(arrays[0].dtype)
5389+
)
5390+
):
5391+
# GH#35488 we need to watch out for multi-block cases
5392+
# We only get here with fill_value not-lib.no_default
5393+
nper = abs(periods)
5394+
nper = min(nper, ncols)
5395+
if periods > 0:
5396+
indexer = np.array(
5397+
[-1] * nper + list(range(ncols - periods)), dtype=np.intp
5398+
)
5399+
else:
5400+
indexer = np.array(
5401+
list(range(nper, ncols)) + [-1] * nper, dtype=np.intp
5402+
)
5403+
mgr = self._mgr.reindex_indexer(
5404+
self.columns,
5405+
indexer,
5406+
axis=0,
5407+
fill_value=fill_value,
5408+
allow_dups=True,
5409+
)
5410+
res_df = self._constructor(mgr)
5411+
return res_df.__finalize__(self, method="shift")
53695412

53705413
return super().shift(
53715414
periods=periods, freq=freq, axis=axis, fill_value=fill_value

pandas/core/internals/managers.py

-45
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
is_1d_only_ea_dtype,
3737
is_dtype_equal,
3838
is_list_like,
39-
needs_i8_conversion,
4039
)
4140
from pandas.core.dtypes.dtypes import ExtensionDtype
4241
from pandas.core.dtypes.generic import (
@@ -366,50 +365,6 @@ def shift(self: T, periods: int, axis: int, fill_value) -> T:
366365
if fill_value is lib.no_default:
367366
fill_value = None
368367

369-
if (
370-
axis == 0
371-
and self.ndim == 2
372-
and (
373-
self.nblocks > 1
374-
or (
375-
# If we only have one block and we know that we can't
376-
# keep the same dtype (i.e. the _can_hold_element check)
377-
# then we can go through the reindex_indexer path
378-
# (and avoid casting logic in the Block method).
379-
# The exception to this (until 2.0) is datetimelike
380-
# dtypes with integers, which cast.
381-
not self.blocks[0]._can_hold_element(fill_value)
382-
# TODO(2.0): remove special case for integer-with-datetimelike
383-
# once deprecation is enforced
384-
and not (
385-
lib.is_integer(fill_value)
386-
and needs_i8_conversion(self.blocks[0].dtype)
387-
)
388-
)
389-
)
390-
):
391-
# GH#35488 we need to watch out for multi-block cases
392-
# We only get here with fill_value not-lib.no_default
393-
ncols = self.shape[0]
394-
nper = abs(periods)
395-
nper = min(nper, ncols)
396-
if periods > 0:
397-
indexer = np.array(
398-
[-1] * nper + list(range(ncols - periods)), dtype=np.intp
399-
)
400-
else:
401-
indexer = np.array(
402-
list(range(nper, ncols)) + [-1] * nper, dtype=np.intp
403-
)
404-
result = self.reindex_indexer(
405-
self.items,
406-
indexer,
407-
axis=0,
408-
fill_value=fill_value,
409-
allow_dups=True,
410-
)
411-
return result
412-
413368
return self.apply("shift", periods=periods, axis=axis, fill_value=fill_value)
414369

415370
def fillna(self: T, value, limit, inplace: bool, downcast) -> T:

pandas/tests/apply/test_str.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -256,19 +256,10 @@ def test_transform_groupby_kernel_series(string_series, op):
256256

257257

258258
@pytest.mark.parametrize("op", frame_transform_kernels)
259-
def test_transform_groupby_kernel_frame(
260-
axis, float_frame, op, using_array_manager, request
261-
):
259+
def test_transform_groupby_kernel_frame(axis, float_frame, op, request):
262260
# TODO(2.0) Remove after pad/backfill deprecation enforced
263261
op = maybe_normalize_deprecated_kernels(op)
264262
# GH 35964
265-
if using_array_manager and op == "pct_change" and axis in (1, "columns"):
266-
# TODO(ArrayManager) shift with axis=1
267-
request.node.add_marker(
268-
pytest.mark.xfail(
269-
reason="shift axis=1 not yet implemented for ArrayManager"
270-
)
271-
)
272263

273264
args = [0.0] if op == "fillna" else []
274265
if axis == 0 or axis == "index":

pandas/tests/frame/methods/test_shift.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,8 @@ def test_shift_dt64values_int_fill_deprecated(self):
611611
)
612612
# TODO(2.0): remove filtering
613613
@pytest.mark.filterwarnings("ignore:Index.ravel.*:FutureWarning")
614-
def test_shift_dt64values_axis1_invalid_fill(
615-
self, vals, as_cat, using_array_manager, request
616-
):
614+
def test_shift_dt64values_axis1_invalid_fill(self, vals, as_cat, request):
617615
# GH#44564
618-
if using_array_manager:
619-
mark = pytest.mark.xfail(raises=NotImplementedError)
620-
request.node.add_marker(mark)
621-
622616
ser = Series(vals)
623617
if as_cat:
624618
ser = ser.astype("category")
@@ -665,7 +659,6 @@ def test_shift_axis1_categorical_columns(self):
665659
)
666660
tm.assert_frame_equal(result, expected)
667661

668-
@td.skip_array_manager_not_yet_implemented
669662
def test_shift_axis1_many_periods(self):
670663
# GH#44978 periods > len(columns)
671664
df = DataFrame(np.random.rand(5, 3))

pandas/tests/groupby/transform/test_transform.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,9 @@ def test_transform_broadcast(tsframe, ts):
165165
assert_fp_equal(res.xs(idx), agged[idx])
166166

167167

168-
def test_transform_axis_1(request, transformation_func, using_array_manager):
168+
def test_transform_axis_1(request, transformation_func):
169169
# GH 36308
170-
if using_array_manager and transformation_func == "pct_change":
171-
# TODO(ArrayManager) column-wise shift
172-
request.node.add_marker(
173-
pytest.mark.xfail(reason="ArrayManager: shift axis=1 not yet implemented")
174-
)
170+
175171
# TODO(2.0) Remove after pad/backfill deprecation enforced
176172
transformation_func = maybe_normalize_deprecated_kernels(transformation_func)
177173

0 commit comments

Comments
 (0)