Skip to content

Commit bff0afd

Browse files
REF: move logic of 'block manager axis' into the BlockManager (#40075)
1 parent 6c0f952 commit bff0afd

File tree

4 files changed

+22
-19
lines changed

4 files changed

+22
-19
lines changed

pandas/core/frame.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -7843,12 +7843,11 @@ def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame:
78437843
raise ValueError("periods must be an integer")
78447844
periods = int(periods)
78457845

7846-
bm_axis = self._get_block_manager_axis(axis)
7847-
7848-
if bm_axis == 0 and periods != 0:
7846+
axis = self._get_axis_number(axis)
7847+
if axis == 1 and periods != 0:
78497848
return self - self.shift(periods, axis=axis)
78507849

7851-
new_data = self._mgr.diff(n=periods, axis=bm_axis)
7850+
new_data = self._mgr.diff(n=periods, axis=axis)
78527851
return self._constructor(new_data).__finalize__(self, "diff")
78537852

78547853
# ----------------------------------------------------------------------

pandas/core/generic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -8968,8 +8968,6 @@ def _where(
89688968
self._info_axis, axis=self._info_axis_number, copy=False
89698969
)
89708970

8971-
block_axis = self._get_block_manager_axis(axis)
8972-
89738971
if inplace:
89748972
# we may have different type blocks come out of putmask, so
89758973
# reconstruct the block manager
@@ -8985,7 +8983,7 @@ def _where(
89858983
cond=cond,
89868984
align=align,
89878985
errors=errors,
8988-
axis=block_axis,
8986+
axis=axis,
89898987
)
89908988
result = self._constructor(new_data)
89918989
return result.__finalize__(self)
@@ -9296,9 +9294,9 @@ def shift(
92969294

92979295
if freq is None:
92989296
# when freq is None, data is shifted, index is not
9299-
block_axis = self._get_block_manager_axis(axis)
9297+
axis = self._get_axis_number(axis)
93009298
new_data = self._mgr.shift(
9301-
periods=periods, axis=block_axis, fill_value=fill_value
9299+
periods=periods, axis=axis, fill_value=fill_value
93029300
)
93039301
return self._constructor(new_data).__finalize__(self, method="shift")
93049302

pandas/core/internals/array_manager.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,7 @@ def apply(
406406

407407
return type(self)(result_arrays, new_axes)
408408

409-
def apply_2d(
410-
self: T,
411-
f,
412-
ignore_failures: bool = False,
413-
**kwargs,
414-
) -> T:
409+
def apply_2d(self: T, f, ignore_failures: bool = False, **kwargs) -> T:
415410
"""
416411
Variant of `apply`, but where the function should not be applied to
417412
each column independently, but to the full data as a 2D array.
@@ -430,7 +425,10 @@ def apply_2d(
430425

431426
return type(self)(result_arrays, new_axes)
432427

433-
def apply_with_block(self: T, f, align_keys=None, **kwargs) -> T:
428+
def apply_with_block(self: T, f, align_keys=None, swap_axis=True, **kwargs) -> T:
429+
# switch axis to follow BlockManager logic
430+
if swap_axis and "axis" in kwargs and self.ndim == 2:
431+
kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0
434432

435433
align_keys = align_keys or []
436434
aligned_args = {k: kwargs[k] for k in align_keys}
@@ -542,7 +540,6 @@ def putmask(self, mask, new, align: bool = True):
542540
)
543541

544542
def diff(self, n: int, axis: int) -> ArrayManager:
545-
axis = self._normalize_axis(axis)
546543
if axis == 1:
547544
# DataFrame only calls this for n=0, in which case performing it
548545
# with axis=0 is equivalent
@@ -551,13 +548,13 @@ def diff(self, n: int, axis: int) -> ArrayManager:
551548
return self.apply(algos.diff, n=n, axis=axis)
552549

553550
def interpolate(self, **kwargs) -> ArrayManager:
554-
return self.apply_with_block("interpolate", **kwargs)
551+
return self.apply_with_block("interpolate", swap_axis=False, **kwargs)
555552

556553
def shift(self, periods: int, axis: int, fill_value) -> ArrayManager:
557554
if fill_value is lib.no_default:
558555
fill_value = None
559556

560-
if axis == 0 and self.ndim == 2:
557+
if axis == 1 and self.ndim == 2:
561558
# TODO column-wise shift
562559
raise NotImplementedError
563560

pandas/core/internals/managers.py

+9
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ def shape(self) -> Shape:
237237
def ndim(self) -> int:
238238
return len(self.axes)
239239

240+
def _normalize_axis(self, axis):
241+
# switch axis to follow BlockManager logic
242+
if self.ndim == 2:
243+
axis = 1 if axis == 0 else 0
244+
return axis
245+
240246
def set_axis(
241247
self, axis: int, new_labels: Index, verify_integrity: bool = True
242248
) -> None:
@@ -560,6 +566,7 @@ def isna(self, func) -> BlockManager:
560566
return self.apply("apply", func=func)
561567

562568
def where(self, other, cond, align: bool, errors: str, axis: int) -> BlockManager:
569+
axis = self._normalize_axis(axis)
563570
if align:
564571
align_keys = ["other", "cond"]
565572
else:
@@ -594,12 +601,14 @@ def putmask(self, mask, new, align: bool = True):
594601
)
595602

596603
def diff(self, n: int, axis: int) -> BlockManager:
604+
axis = self._normalize_axis(axis)
597605
return self.apply("diff", n=n, axis=axis)
598606

599607
def interpolate(self, **kwargs) -> BlockManager:
600608
return self.apply("interpolate", **kwargs)
601609

602610
def shift(self, periods: int, axis: int, fill_value) -> BlockManager:
611+
axis = self._normalize_axis(axis)
603612
if fill_value is lib.no_default:
604613
fill_value = None
605614

0 commit comments

Comments
 (0)