Skip to content

Commit 9d676a9

Browse files
authored
REF: remove axis keyword from Manager/Block.shift (#53845)
* REF: remove axis keyword from Manager/Block.shift * mypy fixup * mypy fixup
1 parent 6c75f4f commit 9d676a9

File tree

6 files changed

+27
-39
lines changed

6 files changed

+27
-39
lines changed

pandas/core/arrays/_mixins.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ def searchsorted(
238238
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)
239239

240240
@doc(ExtensionArray.shift)
241-
def shift(self, periods: int = 1, fill_value=None, axis: AxisInt = 0):
241+
def shift(self, periods: int = 1, fill_value=None):
242+
# NB: shift is always along axis=0
243+
axis = 0
242244
fill_value = self._validate_scalar(fill_value)
243245
new_values = shift(self._ndarray, periods, axis, fill_value)
244246

pandas/core/arrays/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,8 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
851851
If ``periods > len(self)``, then an array of size
852852
len(self) is returned, with all values filled with
853853
``self.dtype.na_value``.
854+
855+
For 2-dimensional ExtensionArrays, we are always shifting along axis=0.
854856
"""
855857
# Note: this implementation assumes that `self.dtype.na_value` can be
856858
# stored in an instance of your ExtensionArray with `self.dtype`.

pandas/core/generic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -10595,9 +10595,8 @@ def shift(
1059510595
if freq is None:
1059610596
# when freq is None, data is shifted, index is not
1059710597
axis = self._get_axis_number(axis)
10598-
new_data = self._mgr.shift(
10599-
periods=periods, axis=axis, fill_value=fill_value
10600-
)
10598+
assert axis == 0 # axis == 1 cases handled in DataFrame.shift
10599+
new_data = self._mgr.shift(periods=periods, fill_value=fill_value)
1060110600
return self._constructor_from_mgr(
1060210601
new_data, axes=new_data.axes
1060310602
).__finalize__(self, method="shift")

pandas/core/internals/array_manager.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,11 @@ def diff(self, n: int) -> Self:
324324
assert self.ndim == 2 # caller ensures
325325
return self.apply(algos.diff, n=n)
326326

327-
def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
327+
def shift(self, periods: int, fill_value) -> Self:
328328
if fill_value is lib.no_default:
329329
fill_value = None
330330

331-
if axis == 1 and self.ndim == 2:
332-
# TODO column-wise shift
333-
raise NotImplementedError
334-
335-
return self.apply_with_block(
336-
"shift", periods=periods, axis=axis, fill_value=fill_value
337-
)
331+
return self.apply_with_block("shift", periods=periods, fill_value=fill_value)
338332

339333
def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
340334
if copy is None:

pandas/core/internals/blocks.py

+16-24
Original file line numberDiff line numberDiff line change
@@ -1473,12 +1473,11 @@ def diff(self, n: int) -> list[Block]:
14731473
new_values = algos.diff(self.values.T, n, axis=0).T
14741474
return [self.make_block(values=new_values)]
14751475

1476-
def shift(
1477-
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
1478-
) -> list[Block]:
1476+
def shift(self, periods: int, fill_value: Any = None) -> list[Block]:
14791477
"""shift the block by periods, possibly upcast"""
14801478
# convert integer to float if necessary. need to do a lot more than
14811479
# that, handle boolean etc also
1480+
axis = self.ndim - 1
14821481

14831482
# Note: periods is never 0 here, as that is handled at the top of
14841483
# NDFrame.shift. If that ever changes, we can do a check for periods=0
@@ -1500,12 +1499,12 @@ def shift(
15001499
)
15011500
except LossySetitemError:
15021501
nb = self.coerce_to_target_dtype(fill_value)
1503-
return nb.shift(periods, axis=axis, fill_value=fill_value)
1502+
return nb.shift(periods, fill_value=fill_value)
15041503

15051504
else:
15061505
values = cast(np.ndarray, self.values)
15071506
new_values = shift(values, periods, axis, casted)
1508-
return [self.make_block(new_values)]
1507+
return [self.make_block_same_class(new_values)]
15091508

15101509
@final
15111510
def quantile(
@@ -1656,6 +1655,18 @@ class EABackedBlock(Block):
16561655

16571656
values: ExtensionArray
16581657

1658+
def shift(self, periods: int, fill_value: Any = None) -> list[Block]:
1659+
"""
1660+
Shift the block by `periods`.
1661+
1662+
Dispatches to underlying ExtensionArray and re-boxes in an
1663+
ExtensionBlock.
1664+
"""
1665+
# Transpose since EA.shift is always along axis=0, while we want to shift
1666+
# along rows.
1667+
new_values = self.values.T.shift(periods=periods, fill_value=fill_value).T
1668+
return [self.make_block_same_class(new_values)]
1669+
16591670
def setitem(self, indexer, value, using_cow: bool = False):
16601671
"""
16611672
Attempt self.values[indexer] = value, possibly creating a new array.
@@ -2108,18 +2119,6 @@ def slice_block_rows(self, slicer: slice) -> Self:
21082119
new_values = self.values[slicer]
21092120
return type(self)(new_values, self._mgr_locs, ndim=self.ndim, refs=self.refs)
21102121

2111-
def shift(
2112-
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
2113-
) -> list[Block]:
2114-
"""
2115-
Shift the block by `periods`.
2116-
2117-
Dispatches to underlying ExtensionArray and re-boxes in an
2118-
ExtensionBlock.
2119-
"""
2120-
new_values = self.values.shift(periods=periods, fill_value=fill_value)
2121-
return [self.make_block_same_class(new_values)]
2122-
21232122
def _unstack(
21242123
self,
21252124
unstacker,
@@ -2226,13 +2225,6 @@ def is_view(self) -> bool:
22262225
# check the ndarray values of the DatetimeIndex values
22272226
return self.values._ndarray.base is not None
22282227

2229-
def shift(
2230-
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
2231-
) -> list[Block]:
2232-
values = self.values
2233-
new_values = values.shift(periods, fill_value=fill_value, axis=axis)
2234-
return [self.make_block_same_class(new_values)]
2235-
22362228

22372229
def _catch_deprecated_value_error(err: Exception) -> None:
22382230
"""

pandas/core/internals/managers.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,11 @@ def diff(self, n: int) -> Self:
378378
# only reached with self.ndim == 2
379379
return self.apply("diff", n=n)
380380

381-
def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
382-
axis = self._normalize_axis(axis)
381+
def shift(self, periods: int, fill_value) -> Self:
383382
if fill_value is lib.no_default:
384383
fill_value = None
385384

386-
return self.apply("shift", periods=periods, axis=axis, fill_value=fill_value)
385+
return self.apply("shift", periods=periods, fill_value=fill_value)
387386

388387
def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
389388
if copy is None:

0 commit comments

Comments
 (0)