Skip to content

Commit 1e4e04b

Browse files
TomAugspurger authored and jreback committed
ENH: ExtensionArray.setitem (#19907)
1 parent 804101c commit 1e4e04b

File tree

8 files changed

+280
-24
lines changed

8 files changed

+280
-24
lines changed

pandas/core/frame.py

+4
Original file line numberDiff line numberDiff line change
@@ -3331,7 +3331,11 @@ def reindexer(value):
33313331
value = reindexer(value).T
33323332

33333333
elif isinstance(value, ExtensionArray):
3334+
from pandas.core.series import _sanitize_index
3335+
# Explicitly copy here, instead of in _sanitize_index,
3336+
# as _sanitize_index won't copy an EA, even with copy=True
33343337
value = value.copy()
3338+
value = _sanitize_index(value, self.index, copy=False)
33353339

33363340
elif isinstance(value, Index) or is_sequence(value):
33373341
from pandas.core.series import _sanitize_index

pandas/core/indexing.py

+43
Original file line numberDiff line numberDiff line change
@@ -2310,6 +2310,49 @@ def check_bool_indexer(ax, key):
23102310
return result
23112311

23122312

2313+
def check_setitem_lengths(indexer, value, values):
2314+
"""Validate that value and indexer are the same length.
2315+
2316+
A special case is allowed for when the indexer is a boolean array
2317+
and the number of true values equals the length of ``value``. In
2318+
this case, no exception is raised.
2319+
2320+
Parameters
2321+
----------
2322+
indexer : sequence
2323+
The key for the setitem
2324+
value : array-like
2325+
The value for the setitem
2326+
values : array-like
2327+
The values being set into
2328+
2329+
Returns
2330+
-------
2331+
None
2332+
2333+
Raises
2334+
------
2335+
ValueError
2336+
When the indexer is an ndarray or list and the lengths don't
2337+
match.
2338+
"""
2339+
# boolean with truth values == len of the value is ok too
2340+
if isinstance(indexer, (np.ndarray, list)):
2341+
if is_list_like(value) and len(indexer) != len(value):
2342+
if not (isinstance(indexer, np.ndarray) and
2343+
indexer.dtype == np.bool_ and
2344+
len(indexer[indexer]) == len(value)):
2345+
raise ValueError("cannot set using a list-like indexer "
2346+
"with a different length than the value")
2347+
# slice
2348+
elif isinstance(indexer, slice):
2349+
2350+
if is_list_like(value) and len(values):
2351+
if len(value) != length_of_indexer(indexer, values):
2352+
raise ValueError("cannot set using a slice indexer with a "
2353+
"different length than the value")
2354+
2355+
23132356
def convert_missing_indexer(indexer):
23142357
""" reverse convert a missing indexer, which is a dict
23152358
return the scalar indexer and a boolean indicating if we converted

pandas/core/internals.py

+53-23
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
import pandas.core.algorithms as algos
6767

6868
from pandas.core.index import Index, MultiIndex, _ensure_index
69-
from pandas.core.indexing import maybe_convert_indices, length_of_indexer
69+
from pandas.core.indexing import maybe_convert_indices, check_setitem_lengths
7070
from pandas.core.arrays import Categorical
7171
from pandas.core.indexes.datetimes import DatetimeIndex
7272
from pandas.core.indexes.timedeltas import TimedeltaIndex
@@ -817,11 +817,24 @@ def _replace_single(self, *args, **kwargs):
817817
return self if kwargs['inplace'] else self.copy()
818818

819819
def setitem(self, indexer, value, mgr=None):
820-
""" set the value inplace; return a new block (of a possibly different
821-
dtype)
820+
"""Set the value inplace, returning a maybe different typed block.
822821
823-
indexer is a direct slice/positional indexer; value must be a
824-
compatible shape
822+
Parameters
823+
----------
824+
indexer : tuple, list-like, array-like, slice
825+
The subset of self.values to set
826+
value : object
827+
The value being set
828+
mgr : BlockPlacement, optional
829+
830+
Returns
831+
-------
832+
Block
833+
834+
Notes
835+
-----
836+
`indexer` is a direct slice/positional indexer. `value` must
837+
be a compatible shape.
825838
"""
826839
# coerce None values, if appropriate
827840
if value is None:
@@ -876,22 +889,7 @@ def setitem(self, indexer, value, mgr=None):
876889
values = transf(values)
877890

878891
# length checking
879-
# boolean with truth values == len of the value is ok too
880-
if isinstance(indexer, (np.ndarray, list)):
881-
if is_list_like(value) and len(indexer) != len(value):
882-
if not (isinstance(indexer, np.ndarray) and
883-
indexer.dtype == np.bool_ and
884-
len(indexer[indexer]) == len(value)):
885-
raise ValueError("cannot set using a list-like indexer "
886-
"with a different length than the value")
887-
888-
# slice
889-
elif isinstance(indexer, slice):
890-
891-
if is_list_like(value) and len(values):
892-
if len(value) != length_of_indexer(indexer, values):
893-
raise ValueError("cannot set using a slice indexer with a "
894-
"different length than the value")
892+
check_setitem_lengths(indexer, value, values)
895893

896894
def _is_scalar_indexer(indexer):
897895
# return True if we are all scalar indexers
@@ -1900,6 +1898,37 @@ def is_view(self):
19001898
"""Extension arrays are never treated as views."""
19011899
return False
19021900

1901+
def setitem(self, indexer, value, mgr=None):
1902+
"""Set the value inplace, returning a same-typed block.
1903+
1904+
This differs from Block.setitem by not allowing setitem to change
1905+
the dtype of the Block.
1906+
1907+
Parameters
1908+
----------
1909+
indexer : tuple, list-like, array-like, slice
1910+
The subset of self.values to set
1911+
value : object
1912+
The value being set
1913+
mgr : BlockPlacement, optional
1914+
1915+
Returns
1916+
-------
1917+
Block
1918+
1919+
Notes
1920+
-----
1921+
`indexer` is a direct slice/positional indexer. `value` must
1922+
be a compatible shape.
1923+
"""
1924+
if isinstance(indexer, tuple):
1925+
# we are always 1-D
1926+
indexer = indexer[0]
1927+
1928+
check_setitem_lengths(indexer, value, self.values)
1929+
self.values[indexer] = value
1930+
return self
1931+
19031932
def get_values(self, dtype=None):
19041933
# ExtensionArrays must be iterable, so this works.
19051934
values = np.asarray(self.values)
@@ -3519,7 +3548,8 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False,
35193548
# with a .values attribute.
35203549
aligned_args = dict((k, kwargs[k])
35213550
for k in align_keys
3522-
if hasattr(kwargs[k], 'values'))
3551+
if hasattr(kwargs[k], 'values') and
3552+
not isinstance(kwargs[k], ABCExtensionArray))
35233553

35243554
for b in self.blocks:
35253555
if filter is not None:
@@ -5220,7 +5250,7 @@ def _safe_reshape(arr, new_shape):
52205250
If possible, reshape `arr` to have shape `new_shape`,
52215251
with a couple of exceptions (see gh-13012):
52225252
5223-
1) If `arr` is a Categorical or Index, `arr` will be
5253+
1) If `arr` is an ExtensionArray or Index, `arr` will be
52245254
returned as is.
52255255
2) If `arr` is a Series, the `_values` attribute will
52265256
be reshaped and returned.

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ class TestMyDtype(BaseDtypeTests):
4949
from .methods import BaseMethodsTests # noqa
5050
from .missing import BaseMissingTests # noqa
5151
from .reshaping import BaseReshapingTests # noqa
52+
from .setitem import BaseSetitemTests # noqa
+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import operator
2+
3+
import numpy as np
4+
import pytest
5+
6+
import pandas as pd
7+
import pandas.util.testing as tm
8+
from .base import BaseExtensionTests
9+
10+
11+
class BaseSetitemTests(BaseExtensionTests):
12+
def test_setitem_scalar_series(self, data):
13+
arr = pd.Series(data)
14+
arr[0] = data[1]
15+
assert arr[0] == data[1]
16+
17+
def test_setitem_sequence(self, data):
18+
arr = pd.Series(data)
19+
original = data.copy()
20+
21+
arr[[0, 1]] = [data[1], data[0]]
22+
assert arr[0] == original[1]
23+
assert arr[1] == original[0]
24+
25+
@pytest.mark.parametrize('as_array', [True, False])
26+
def test_setitem_sequence_mismatched_length_raises(self, data, as_array):
27+
ser = pd.Series(data)
28+
value = [data[0]]
29+
if as_array:
30+
value = type(data)(value)
31+
32+
xpr = 'cannot set using a {} indexer with a different length'
33+
with tm.assert_raises_regex(ValueError, xpr.format('list-like')):
34+
ser[[0, 1]] = value
35+
36+
with tm.assert_raises_regex(ValueError, xpr.format('slice')):
37+
ser[slice(3)] = value
38+
39+
def test_setitem_empty_indxer(self, data):
40+
ser = pd.Series(data)
41+
original = ser.copy()
42+
ser[[]] = []
43+
self.assert_series_equal(ser, original)
44+
45+
def test_setitem_sequence_broadcasts(self, data):
46+
arr = pd.Series(data)
47+
48+
arr[[0, 1]] = data[2]
49+
assert arr[0] == data[2]
50+
assert arr[1] == data[2]
51+
52+
@pytest.mark.parametrize('setter', ['loc', 'iloc'])
53+
def test_setitem_scalar(self, data, setter):
54+
arr = pd.Series(data)
55+
setter = getattr(arr, setter)
56+
operator.setitem(setter, 0, data[1])
57+
assert arr[0] == data[1]
58+
59+
def test_setitem_loc_scalar_mixed(self, data):
60+
df = pd.DataFrame({"A": np.arange(len(data)), "B": data})
61+
df.loc[0, 'B'] = data[1]
62+
assert df.loc[0, 'B'] == data[1]
63+
64+
def test_setitem_loc_scalar_single(self, data):
65+
df = pd.DataFrame({"B": data})
66+
df.loc[10, 'B'] = data[1]
67+
assert df.loc[10, 'B'] == data[1]
68+
69+
def test_setitem_loc_scalar_multiple_homogoneous(self, data):
70+
df = pd.DataFrame({"A": data, "B": data})
71+
df.loc[10, 'B'] = data[1]
72+
assert df.loc[10, 'B'] == data[1]
73+
74+
def test_setitem_iloc_scalar_mixed(self, data):
75+
df = pd.DataFrame({"A": np.arange(len(data)), "B": data})
76+
df.iloc[0, 1] = data[1]
77+
assert df.loc[0, 'B'] == data[1]
78+
79+
def test_setitem_iloc_scalar_single(self, data):
80+
df = pd.DataFrame({"B": data})
81+
df.iloc[10, 0] = data[1]
82+
assert df.loc[10, 'B'] == data[1]
83+
84+
def test_setitem_iloc_scalar_multiple_homogoneous(self, data):
85+
df = pd.DataFrame({"A": data, "B": data})
86+
df.iloc[10, 1] = data[1]
87+
assert df.loc[10, 'B'] == data[1]
88+
89+
@pytest.mark.parametrize('as_callable', [True, False])
90+
@pytest.mark.parametrize('setter', ['loc', None])
91+
def test_setitem_mask_aligned(self, data, as_callable, setter):
92+
ser = pd.Series(data)
93+
mask = np.zeros(len(data), dtype=bool)
94+
mask[:2] = True
95+
96+
if as_callable:
97+
mask2 = lambda x: mask
98+
else:
99+
mask2 = mask
100+
101+
if setter:
102+
# loc
103+
target = getattr(ser, setter)
104+
else:
105+
# Series.__setitem__
106+
target = ser
107+
108+
operator.setitem(target, mask2, data[5:7])
109+
110+
ser[mask2] = data[5:7]
111+
assert ser[0] == data[5]
112+
assert ser[1] == data[6]
113+
114+
@pytest.mark.parametrize('setter', ['loc', None])
115+
def test_setitem_mask_broadcast(self, data, setter):
116+
ser = pd.Series(data)
117+
mask = np.zeros(len(data), dtype=bool)
118+
mask[:2] = True
119+
120+
if setter: # loc
121+
target = getattr(ser, setter)
122+
else: # __setitem__
123+
target = ser
124+
125+
operator.setitem(target, mask, data[10])
126+
assert ser[0] == data[10]
127+
assert ser[1] == data[10]
128+
129+
def test_setitem_expand_columns(self, data):
130+
df = pd.DataFrame({"A": data})
131+
result = df.copy()
132+
result['B'] = 1
133+
expected = pd.DataFrame({"A": data, "B": [1] * len(data)})
134+
self.assert_frame_equal(result, expected)
135+
136+
result = df.copy()
137+
result.loc[:, 'B'] = 1
138+
self.assert_frame_equal(result, expected)
139+
140+
# overwrite with new type
141+
result['B'] = data
142+
expected = pd.DataFrame({"A": data, "B": data})
143+
self.assert_frame_equal(result, expected)
144+
145+
def test_setitem_expand_with_extension(self, data):
146+
df = pd.DataFrame({"A": [1] * len(data)})
147+
result = df.copy()
148+
result['B'] = data
149+
expected = pd.DataFrame({"A": [1] * len(data), "B": data})
150+
self.assert_frame_equal(result, expected)
151+
152+
result = df.copy()
153+
result.loc[:, 'B'] = data
154+
self.assert_frame_equal(result, expected)
155+
156+
def test_setitem_frame_invalid_length(self, data):
157+
df = pd.DataFrame({"A": [1] * len(data)})
158+
xpr = "Length of values does not match length of index"
159+
with tm.assert_raises_regex(ValueError, xpr):
160+
df['B'] = data[:5]
161+
162+
@pytest.mark.xfail(reason="GH-20441: setitem on extension types.")
163+
def test_setitem_tuple_index(self, data):
164+
s = pd.Series(data[:2], index=[(0, 0), (0, 1)])
165+
expected = pd.Series(data.take([1, 1]), index=s.index)
166+
s[(0, 1)] = data[1]
167+
self.assert_series_equal(s, expected)

pandas/tests/extension/category/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def test_getitem_scalar(self):
8585
pass
8686

8787

88+
class TestSetitem(base.BaseSetitemTests):
89+
pass
90+
91+
8892
class TestMissing(base.BaseMissingTests):
8993

9094
@pytest.mark.skip(reason="Not implemented")

pandas/tests/extension/decimal/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def __init__(self, values):
3131
values = np.asarray(values, dtype=object)
3232

3333
self.values = values
34+
# Some aliases for common attribute names to ensure pandas supports
35+
# these
36+
self._items = self._data = self.data = self.values
3437

3538
@classmethod
3639
def _constructor_from_sequence(cls, scalars):
@@ -62,7 +65,7 @@ def __len__(self):
6265
return len(self.values)
6366

6467
def __repr__(self):
65-
return repr(self.values)
68+
return 'DecimalArray({!r})'.format(self.values)
6669

6770
@property
6871
def nbytes(self):

pandas/tests/extension/json/test_json.py

+4
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def test_astype_str(self):
135135
"""
136136

137137

138+
# We intentionally don't run base.BaseSetitemTests because pandas'
139+
# internals have trouble setting sequences of values into scalar positions.
140+
141+
138142
class TestGroupby(base.BaseGroupbyTests):
139143

140144
@unhashable

0 commit comments

Comments
 (0)