Skip to content

Commit f444213

Browse files
authored
TYP: ensure Block.putmask, Block.where get arrays, not Series/DataFrame (#32962)
1 parent e3f1cf1 commit f444213

File tree

3 files changed

+41
-24
lines changed

3 files changed

+41
-24
lines changed

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8609,7 +8609,7 @@ def _where(
86098609
# GH 2745 / GH 4192
86108610
# treat like a scalar
86118611
if len(other) == 1:
8612-
other = np.array(other[0])
8612+
other = other[0]
86138613

86148614
# GH 3235
86158615
# match True cond to other

pandas/core/internals/blocks.py

+36-22
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from pandas.core.dtypes.generic import (
5656
ABCDataFrame,
5757
ABCExtensionArray,
58+
ABCIndexClass,
5859
ABCPandasArray,
5960
ABCSeries,
6061
)
@@ -913,7 +914,7 @@ def putmask(
913914
914915
Parameters
915916
----------
916-
mask : the condition to respect
917+
mask : np.ndarray[bool], SparseArray[bool], or BooleanArray
917918
new : a ndarray/object
918919
inplace : bool, default False
919920
Perform inplace modification.
@@ -925,10 +926,10 @@ def putmask(
925926
-------
926927
List[Block]
927928
"""
928-
new_values = self.values if inplace else self.values.copy()
929+
mask = _extract_bool_array(mask)
930+
assert not isinstance(new, (ABCIndexClass, ABCSeries, ABCDataFrame))
929931

930-
new = getattr(new, "values", new)
931-
mask = getattr(mask, "values", mask)
932+
new_values = self.values if inplace else self.values.copy()
932933

933934
# if we are passed a scalar None, convert it here
934935
if not is_list_like(new) and isna(new) and not self.is_object:
@@ -1308,18 +1309,21 @@ def where(
13081309
Parameters
13091310
----------
13101311
other : a ndarray/object
1311-
cond : the condition to respect
1312+
cond : np.ndarray[bool], SparseArray[bool], or BooleanArray
13121313
errors : str, {'raise', 'ignore'}, default 'raise'
13131314
- ``raise`` : allow exceptions to be raised
13141315
- ``ignore`` : suppress exceptions. On error return original object
13151316
axis : int, default 0
13161317
13171318
Returns
13181319
-------
1319-
a new block(s), the result of the func
1320+
List[Block]
13201321
"""
13211322
import pandas.core.computation.expressions as expressions
13221323

1324+
cond = _extract_bool_array(cond)
1325+
assert not isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame))
1326+
13231327
assert errors in ["raise", "ignore"]
13241328
transpose = self.ndim == 2
13251329

@@ -1328,9 +1332,6 @@ def where(
13281332
if transpose:
13291333
values = values.T
13301334

1331-
other = getattr(other, "_values", getattr(other, "values", other))
1332-
cond = getattr(cond, "values", cond)
1333-
13341335
# If the default broadcasting would go in the wrong direction, then
13351336
# explicitly reshape other instead
13361337
if getattr(other, "ndim", 0) >= 1:
@@ -1628,9 +1629,9 @@ def putmask(
16281629
"""
16291630
inplace = validate_bool_kwarg(inplace, "inplace")
16301631

1631-
# use block's copy logic.
1632-
# .values may be an Index which does shallow copy by default
1633-
new_values = self.values if inplace else self.copy().values
1632+
mask = _extract_bool_array(mask)
1633+
1634+
new_values = self.values if inplace else self.values.copy()
16341635

16351636
if isinstance(new, np.ndarray) and len(new) == len(mask):
16361637
new = new[mask]
@@ -1859,19 +1860,19 @@ def shift(
18591860
def where(
18601861
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0,
18611862
) -> List["Block"]:
1862-
if isinstance(other, ABCDataFrame):
1863-
# ExtensionArrays are 1-D, so if we get here then
1864-
# `other` should be a DataFrame with a single column.
1865-
assert other.shape[1] == 1
1866-
other = other.iloc[:, 0]
18671863

1868-
other = extract_array(other, extract_numpy=True)
1864+
cond = _extract_bool_array(cond)
1865+
assert not isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame))
18691866

1870-
if isinstance(cond, ABCDataFrame):
1871-
assert cond.shape[1] == 1
1872-
cond = cond.iloc[:, 0]
1867+
if isinstance(other, np.ndarray) and other.ndim == 2:
1868+
# TODO(EA2D): unnecessary with 2D EAs
1869+
assert other.shape[1] == 1
1870+
other = other[:, 0]
18731871

1874-
cond = extract_array(cond, extract_numpy=True)
1872+
if isinstance(cond, np.ndarray) and cond.ndim == 2:
1873+
# TODO(EA2D): unnecessary with 2D EAs
1874+
assert cond.shape[1] == 1
1875+
cond = cond[:, 0]
18751876

18761877
if lib.is_scalar(other) and isna(other):
18771878
# The default `other` for Series / Frame is np.nan
@@ -3113,3 +3114,16 @@ def _putmask_preserve(nv, n):
31133114
v = v.astype(dtype)
31143115

31153116
return _putmask_preserve(v, n)
3117+
3118+
3119+
def _extract_bool_array(mask: ArrayLike) -> np.ndarray:
3120+
"""
3121+
If we have a SparseArray or BooleanArray, convert it to ndarray[bool].
3122+
"""
3123+
if isinstance(mask, ExtensionArray):
3124+
# We could have BooleanArray, Sparse[bool], ...
3125+
mask = np.asarray(mask, dtype=np.bool_)
3126+
3127+
assert isinstance(mask, np.ndarray), type(mask)
3128+
assert mask.dtype == bool, mask.dtype
3129+
return mask

pandas/core/internals/managers.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import pandas.core.algorithms as algos
3434
from pandas.core.arrays.sparse import SparseDtype
3535
from pandas.core.base import PandasObject
36+
from pandas.core.construction import extract_array
3637
from pandas.core.indexers import maybe_convert_indices
3738
from pandas.core.indexes.api import Index, ensure_index
3839
from pandas.core.internals.blocks import (
@@ -426,7 +427,7 @@ def apply(self: T, f, filter=None, align_keys=None, **kwargs) -> T:
426427

427428
for k, obj in aligned_args.items():
428429
axis = obj._info_axis_number
429-
kwargs[k] = obj.reindex(b_items, axis=axis, copy=align_copy)
430+
kwargs[k] = obj.reindex(b_items, axis=axis, copy=align_copy)._values
430431

431432
if callable(f):
432433
applied = b.apply(f, **kwargs)
@@ -552,6 +553,7 @@ def where(self, **kwargs) -> "BlockManager":
552553
align_keys = ["other", "cond"]
553554
else:
554555
align_keys = ["cond"]
556+
kwargs["other"] = extract_array(kwargs["other"], extract_numpy=True)
555557

556558
return self.apply("where", align_keys=align_keys, **kwargs)
557559

@@ -567,6 +569,7 @@ def putmask(
567569
align_keys = ["new", "mask"]
568570
else:
569571
align_keys = ["mask"]
572+
new = extract_array(new, extract_numpy=True)
570573

571574
return self.apply(
572575
"putmask",

0 commit comments

Comments
 (0)