Skip to content

Commit 342eb1e

Browse files
committed
REF: pass align_keys to apply (pandas-dev#32846)
1 parent e07d0ed commit 342eb1e

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

pandas/core/internals/managers.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from pandas.core.dtypes.concat import concat_compat
2929
from pandas.core.dtypes.dtypes import ExtensionDtype
30-
from pandas.core.dtypes.generic import ABCExtensionArray, ABCSeries
30+
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
3131
from pandas.core.dtypes.missing import isna
3232

3333
import pandas.core.algorithms as algos
@@ -375,7 +375,7 @@ def reduce(self, func, *args, **kwargs):
375375

376376
return res
377377

378-
def apply(self: T, f, filter=None, **kwargs) -> T:
378+
def apply(self: T, f, filter=None, align_keys=None, **kwargs) -> T:
379379
"""
380380
Iterate over the blocks, collect and create a new BlockManager.
381381
@@ -390,6 +390,7 @@ def apply(self: T, f, filter=None, **kwargs) -> T:
390390
-------
391391
BlockManager
392392
"""
393+
align_keys = align_keys or []
393394
result_blocks = []
394395
# fillna: Series/DataFrame is responsible for making sure value is aligned
395396

@@ -404,28 +405,14 @@ def apply(self: T, f, filter=None, **kwargs) -> T:
404405

405406
self._consolidate_inplace()
406407

408+
align_copy = False
407409
if f == "where":
408410
align_copy = True
409-
if kwargs.get("align", True):
410-
align_keys = ["other", "cond"]
411-
else:
412-
align_keys = ["cond"]
413-
elif f == "putmask":
414-
align_copy = False
415-
if kwargs.get("align", True):
416-
align_keys = ["new", "mask"]
417-
else:
418-
align_keys = ["mask"]
419-
else:
420-
align_keys = []
421411

422-
# TODO(EA): may interfere with ExtensionBlock.setitem for blocks
423-
# with a .values attribute.
424412
aligned_args = {
425413
k: kwargs[k]
426414
for k in align_keys
427-
if not isinstance(kwargs[k], ABCExtensionArray)
428-
and hasattr(kwargs[k], "values")
415+
if isinstance(kwargs[k], (ABCSeries, ABCDataFrame))
429416
}
430417

431418
for b in self.blocks:
@@ -561,13 +548,24 @@ def isna(self, func) -> "BlockManager":
561548
return self.apply("apply", func=func)
562549

563550
def where(self, **kwargs) -> "BlockManager":
564-
return self.apply("where", **kwargs)
551+
if kwargs.pop("align", True):
552+
align_keys = ["other", "cond"]
553+
else:
554+
align_keys = ["cond"]
555+
556+
return self.apply("where", align_keys=align_keys, **kwargs)
565557

566558
def setitem(self, indexer, value) -> "BlockManager":
567559
return self.apply("setitem", indexer=indexer, value=value)
568560

569561
def putmask(self, **kwargs):
570-
return self.apply("putmask", **kwargs)
562+
563+
if kwargs.pop("align", True):
564+
align_keys = ["new", "mask"]
565+
else:
566+
align_keys = ["mask"]
567+
568+
return self.apply("putmask", align_keys=align_keys, **kwargs)
571569

572570
def diff(self, n: int, axis: int) -> "BlockManager":
573571
return self.apply("diff", n=n, axis=axis)

0 commit comments

Comments
 (0)