Skip to content

Commit 4082c46

Browse files
BUG: Array.__setitem__ failing with nullable boolean mask (#31484) (#31562)
1 parent 2181824 commit 4082c46

File tree

11 files changed

+45
-0
lines changed

11 files changed

+45
-0
lines changed

doc/source/whatsnew/v1.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ Indexing
7171

7272
-
7373
-
74+
- Bug where assigning to a :class:`Series` using a IntegerArray / BooleanArray as a mask would raise ``TypeError`` (:issue:`31446`)
7475

7576
Missing
7677
^^^^^^^

pandas/core/arrays/boolean.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pandas.core.dtypes.missing import isna, notna
2727

2828
from pandas.core import nanops, ops
29+
from pandas.core.indexers import check_array_indexer
2930

3031
from .masked import BaseMaskedArray
3132

@@ -369,6 +370,7 @@ def __setitem__(self, key, value):
369370
value = value[0]
370371
mask = mask[0]
371372

373+
key = check_array_indexer(self, key)
372374
self._data[key] = value
373375
self._mask[key] = mask
374376

pandas/core/arrays/categorical.py

+2
Original file line numberDiff line numberDiff line change
@@ -2073,6 +2073,8 @@ def __setitem__(self, key, value):
20732073

20742074
lindexer = self.categories.get_indexer(rvalue)
20752075
lindexer = self._maybe_coerce_indexer(lindexer)
2076+
2077+
key = check_array_indexer(self, key)
20762078
self._codes[key] = lindexer
20772079

20782080
def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:

pandas/core/arrays/datetimelike.py

+2
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,8 @@ def __setitem__(
601601
f"or array of those. Got '{type(value).__name__}' instead."
602602
)
603603
raise TypeError(msg)
604+
605+
key = check_array_indexer(self, key)
604606
self._data[key] = value
605607
self._maybe_clear_freq()
606608

pandas/core/arrays/integer.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pandas.core.dtypes.missing import isna
2626

2727
from pandas.core import nanops, ops
28+
from pandas.core.indexers import check_array_indexer
2829
from pandas.core.ops import invalid_comparison
2930
from pandas.core.ops.common import unpack_zerodim_and_defer
3031
from pandas.core.tools.numeric import to_numeric
@@ -414,6 +415,7 @@ def __setitem__(self, key, value):
414415
value = value[0]
415416
mask = mask[0]
416417

418+
key = check_array_indexer(self, key)
417419
self._data[key] = value
418420
self._mask[key] = mask
419421

pandas/core/arrays/interval.py

+1
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def __setitem__(self, key, value):
541541
msg = f"'value' should be an interval type, got {type(value)} instead."
542542
raise TypeError(msg)
543543

544+
key = check_array_indexer(self, key)
544545
# Need to ensure that left and right are updated atomically, so we're
545546
# forced to copy, update the copy, and swap in the new values.
546547
left = self.left.copy(deep=True)

pandas/core/arrays/numpy_.py

+1
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def __getitem__(self, item):
243243
def __setitem__(self, key, value):
244244
value = extract_array(value, extract_numpy=True)
245245

246+
key = check_array_indexer(self, key)
246247
scalar_key = lib.is_scalar(key)
247248
scalar_value = lib.is_scalar(value)
248249

pandas/core/arrays/string_.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pandas.core import ops
1616
from pandas.core.arrays import PandasArray
1717
from pandas.core.construction import extract_array
18+
from pandas.core.indexers import check_array_indexer
1819
from pandas.core.missing import isna
1920

2021

@@ -224,6 +225,7 @@ def __setitem__(self, key, value):
224225
# extract_array doesn't extract PandasArray subclasses
225226
value = value._ndarray
226227

228+
key = check_array_indexer(self, key)
227229
scalar_key = lib.is_scalar(key)
228230
scalar_value = lib.is_scalar(value)
229231
if scalar_key and not scalar_value:

pandas/tests/arrays/test_integer.py

+17
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,23 @@ def test_cut(bins, right, include_lowest):
10721072
tm.assert_categorical_equal(result, expected)
10731073

10741074

1075+
def test_array_setitem_nullable_boolean_mask():
1076+
# GH 31446
1077+
ser = pd.Series([1, 2], dtype="Int64")
1078+
result = ser.where(ser > 1)
1079+
expected = pd.Series([pd.NA, 2], dtype="Int64")
1080+
tm.assert_series_equal(result, expected)
1081+
1082+
1083+
def test_array_setitem():
1084+
# GH 31446
1085+
arr = pd.Series([1, 2], dtype="Int64").array
1086+
arr[arr > 1] = 1
1087+
1088+
expected = pd.array([1, 1], dtype="Int64")
1089+
tm.assert_extension_array_equal(arr, expected)
1090+
1091+
10751092
# TODO(jreback) - these need testing / are broken
10761093

10771094
# shift

pandas/tests/extension/base/setitem.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import pandas as pd
7+
from pandas.core.arrays.numpy_ import PandasDtype
78

89
from .base import BaseExtensionTests
910

@@ -195,3 +196,14 @@ def test_setitem_preserves_views(self, data):
195196
data[0] = data[1]
196197
assert view1[0] == data[1]
197198
assert view2[0] == data[1]
199+
200+
def test_setitem_nullable_mask(self, data):
201+
# GH 31446
202+
# TODO: there is some issue with PandasArray, therefore,
203+
# TODO: skip the setitem test for now, and fix it later
204+
if data.dtype != PandasDtype("object"):
205+
arr = data[:5]
206+
expected = data.take([0, 0, 0, 3, 4])
207+
mask = pd.array([True, True, True, False, False])
208+
arr[mask] = data[0]
209+
self.assert_extension_array_equal(expected, arr)

pandas/tests/extension/decimal/array.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas as pd
1111
from pandas.api.extensions import no_default, register_extension_dtype
1212
from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin
13+
from pandas.core.indexers import check_array_indexer
1314

1415

1516
@register_extension_dtype
@@ -144,6 +145,8 @@ def __setitem__(self, key, value):
144145
value = [decimal.Decimal(v) for v in value]
145146
else:
146147
value = decimal.Decimal(value)
148+
149+
key = check_array_indexer(self, key)
147150
self._data[key] = value
148151

149152
def __len__(self) -> int:

0 commit comments

Comments
 (0)