Skip to content

Commit 95b0e14

Browse files
TST: expand tests for ExtensionArray setitem with nullable arrays (#31741)
1 parent 4ac1e5f commit 95b0e14

File tree

2 files changed

+133
-12
lines changed

2 files changed

+133
-12
lines changed

pandas/tests/extension/base/setitem.py

+87-12
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import pandas as pd
7-
from pandas.core.arrays.numpy_ import PandasDtype
7+
import pandas._testing as tm
88

99
from .base import BaseExtensionTests
1010

@@ -93,6 +93,92 @@ def test_setitem_iloc_scalar_multiple_homogoneous(self, data):
9393
df.iloc[10, 1] = data[1]
9494
assert df.loc[10, "B"] == data[1]
9595

96+
@pytest.mark.parametrize(
    "mask",
    [
        pytest.param(np.array([True, True, True, False, False]), id="numpy-array"),
        pytest.param(
            pd.array([True, True, True, False, False], dtype="boolean"),
            id="boolean-array",
        ),
    ],
)
def test_setitem_mask(self, data, mask, box_in_series):
    # Setting a scalar through a boolean mask (numpy bool array or nullable
    # "boolean" array) must overwrite exactly the positions that are True.
    target = data[:5].copy()
    # The mask selects positions 0-2, which all receive data[0];
    # positions 3 and 4 keep their original values.
    wanted = target.take([0, 0, 0, 3, 4])
    if box_in_series:
        target, wanted = pd.Series(target), pd.Series(wanted)
    target[mask] = data[0]
    self.assert_equal(wanted, target)
112+
113+
def test_setitem_mask_raises(self, data, box_in_series):
    # A boolean mask whose length does not match the target must be rejected,
    # whether it is a numpy bool array or a nullable "boolean" array.
    short_mask = np.array([True, False])  # wrong length

    if box_in_series:
        data = pd.Series(data)

    for mask in (short_mask, pd.array(short_mask, dtype="boolean")):
        with pytest.raises(IndexError, match="wrong length"):
            data[mask] = data[0]
126+
127+
def test_setitem_mask_boolean_array_raises(self, data, box_in_series):
    # A nullable boolean mask that contains missing values cannot be used
    # for setting and must raise ValueError.
    na_mask = pd.array(np.zeros(data.shape, dtype="bool"), dtype="boolean")
    na_mask[:2] = pd.NA  # first two entries of the mask are missing

    target = pd.Series(data) if box_in_series else data

    # Either wording of the error message is accepted.
    accepted = [
        "Cannot mask with a boolean indexer containing NA values",
        "cannot mask with array containing NA / NaN values",
    ]
    with pytest.raises(ValueError, match="|".join(accepted)):
        target[na_mask] = target[0]
141+
142+
@pytest.mark.parametrize(
    "idx",
    [
        pytest.param([0, 1, 2], id="list"),
        pytest.param(pd.array([0, 1, 2], dtype="Int64"), id="integer-array"),
        pytest.param(np.array([0, 1, 2]), id="numpy-array"),
    ],
)
def test_setitem_integer_array(self, data, idx, box_in_series):
    # Setting a scalar through an integer indexer (list, nullable Int64 array
    # or numpy array) must overwrite exactly the indexed positions.
    target = data[:5].copy()
    # Positions 0-2 all receive the first element; 3 and 4 are untouched.
    wanted = data.take([0, 0, 0, 3, 4])

    if box_in_series:
        target, wanted = pd.Series(target), pd.Series(wanted)

    target[idx] = target[0]
    self.assert_equal(target, wanted)
157+
158+
@pytest.mark.parametrize(
    "idx, box_in_series",
    [
        pytest.param([0, 1, 2, pd.NA], False, id="list-False"),
        pytest.param(
            [0, 1, 2, pd.NA],
            True,
            marks=pytest.mark.xfail(reason="GH-31948"),
            id="list-True",
        ),
        pytest.param(
            pd.array([0, 1, 2, pd.NA], dtype="Int64"), False, id="integer-array-False"
        ),
        # This case used to repeat the box_in_series=False case above even
        # though its id says "True", so the Series-boxed path was never run.
        # NOTE(review): if the Series path turns out to hit the same
        # label-based lookup as "list-True", mark it xfail (GH-31948) rather
        # than switching it back to False.
        pytest.param(
            pd.array([0, 1, 2, pd.NA], dtype="Int64"), True, id="integer-array-True"
        ),
    ],
)
def test_setitem_integer_with_missing_raises(self, data, idx, box_in_series):
    # Setting through an integer indexer that contains NA must raise
    # ValueError. ``box_in_series`` is parametrized directly here so that each
    # (indexer, boxing) combination can carry its own marks; ids are attached
    # to each case so they cannot drift out of sync with the values.
    arr = data.copy()

    # TODO(xfail) this raises KeyError about labels not found (it tries label-based)
    # for list of labels with Series
    if box_in_series:
        # Random string labels give the Series a non-integer index.
        arr = pd.Series(data, index=[tm.rands(4) for _ in range(len(data))])

    msg = "Cannot index with an integer indexer containing NA values"
    with pytest.raises(ValueError, match=msg):
        arr[idx] = arr[0]
181+
96182
@pytest.mark.parametrize("as_callable", [True, False])
97183
@pytest.mark.parametrize("setter", ["loc", None])
98184
def test_setitem_mask_aligned(self, data, as_callable, setter):
@@ -219,14 +305,3 @@ def test_setitem_preserves_views(self, data):
219305
data[0] = data[1]
220306
assert view1[0] == data[1]
221307
assert view2[0] == data[1]
222-
223-
def test_setitem_nullable_mask(self, data):
224-
# GH 31446
225-
# TODO: there is some issue with PandasArray, therefore,
226-
# TODO: skip the setitem test for now, and fix it later
227-
if data.dtype != PandasDtype("object"):
228-
arr = data[:5]
229-
expected = data.take([0, 0, 0, 3, 4])
230-
mask = pd.array([True, True, True, False, False])
231-
arr[mask] = data[0]
232-
self.assert_extension_array_equal(expected, arr)

pandas/tests/extension/test_numpy.py

+46
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,52 @@ def test_setitem_scalar_key_sequence_raise(self, data):
396396
# Failed: DID NOT RAISE <class 'ValueError'>
397397
super().test_setitem_scalar_key_sequence_raise(data)
398398

399+
# TODO: there is some issue with PandasArray, therefore,
400+
# skip the setitem test for now, and fix it later (GH 31446)
401+
402+
@skip_nested
@pytest.mark.parametrize(
    "mask",
    [
        pytest.param(np.array([True, True, True, False, False]), id="numpy-array"),
        pytest.param(
            pd.array([True, True, True, False, False], dtype="boolean"),
            id="boolean-array",
        ),
    ],
)
def test_setitem_mask(self, data, mask, box_in_series):
    # Delegates unchanged to the inherited test; the override exists to
    # attach ``skip_nested`` (see the GH 31446 note preceding these overrides).
    super().test_setitem_mask(data=data, mask=mask, box_in_series=box_in_series)
413+
414+
@skip_nested
def test_setitem_mask_raises(self, data, box_in_series):
    # Delegates unchanged to the inherited test; the override only adds
    # the ``skip_nested`` decorator.
    super().test_setitem_mask_raises(data=data, box_in_series=box_in_series)
417+
418+
@skip_nested
def test_setitem_mask_boolean_array_raises(self, data, box_in_series):
    # Delegates unchanged to the inherited test; the override only adds
    # the ``skip_nested`` decorator.
    super().test_setitem_mask_boolean_array_raises(
        data=data, box_in_series=box_in_series
    )
421+
422+
@skip_nested
@pytest.mark.parametrize(
    "idx",
    [
        pytest.param([0, 1, 2], id="list"),
        pytest.param(pd.array([0, 1, 2], dtype="Int64"), id="integer-array"),
        pytest.param(np.array([0, 1, 2]), id="numpy-array"),
    ],
)
def test_setitem_integer_array(self, data, idx, box_in_series):
    # Delegates unchanged to the inherited test; the override exists to
    # attach ``skip_nested``.
    super().test_setitem_integer_array(
        data=data, idx=idx, box_in_series=box_in_series
    )
430+
431+
@skip_nested
@pytest.mark.parametrize(
    "idx, box_in_series",
    [
        pytest.param([0, 1, 2, pd.NA], False, id="list-False"),
        pytest.param([0, 1, 2, pd.NA], True, marks=pytest.mark.xfail, id="list-True"),
        pytest.param(
            pd.array([0, 1, 2, pd.NA], dtype="Int64"), False, id="integer-array-False"
        ),
        # NOTE(review): this case is identical to the one above
        # (box_in_series=False) although its id says "True" -- confirm whether
        # True was intended and whether that path raises ValueError here.
        pytest.param(
            pd.array([0, 1, 2, pd.NA], dtype="Int64"), False, id="integer-array-True"
        ),
    ],
)
def test_setitem_integer_with_missing_raises(self, data, idx, box_in_series):
    # Delegates unchanged to the inherited test; the override exists to
    # attach ``skip_nested``.
    super().test_setitem_integer_with_missing_raises(
        data=data, idx=idx, box_in_series=box_in_series
    )
444+
399445
@skip_nested
400446
def test_setitem_slice(self, data, box_in_series):
401447
super().test_setitem_slice(data, box_in_series)

0 commit comments

Comments
 (0)