-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: ExtensionArray.setitem #19907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: ExtensionArray.setitem #19907
Changes from all commits
b985ea8
9489f6c
79da90e
7f65c5a
8ff6168
274da13
35ae908
c768709
76f6e86
2d5b08c
f66c093
43dfd7d
5c1d934
66bbe9a
f47ddf2
1e5a14c
abe734d
10a3f19
9a5b8c9
202fae8
3cbe078
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ | |
import pandas.core.algorithms as algos | ||
|
||
from pandas.core.index import Index, MultiIndex, _ensure_index | ||
from pandas.core.indexing import maybe_convert_indices, length_of_indexer | ||
from pandas.core.indexing import maybe_convert_indices, check_setitem_lengths | ||
from pandas.core.arrays import Categorical | ||
from pandas.core.indexes.datetimes import DatetimeIndex | ||
from pandas.core.indexes.timedeltas import TimedeltaIndex | ||
|
@@ -817,11 +817,24 @@ def _replace_single(self, *args, **kwargs): | |
return self if kwargs['inplace'] else self.copy() | ||
|
||
def setitem(self, indexer, value, mgr=None): | ||
""" set the value inplace; return a new block (of a possibly different | ||
dtype) | ||
"""Set the value inplace, returning a maybe different typed block. | ||
|
||
indexer is a direct slice/positional indexer; value must be a | ||
compatible shape | ||
Parameters | ||
---------- | ||
indexer : tuple, list-like, array-like, slice | ||
The subset of self.values to set | ||
value : object | ||
The value being set | ||
mgr : BlockPlacement, optional | ||
|
||
Returns | ||
------- | ||
Block | ||
|
||
Notes | ||
----- | ||
`indexer` is a direct slice/positional indexer. `value` must | ||
be a compatible shape. | ||
""" | ||
# coerce None values, if appropriate | ||
if value is None: | ||
|
@@ -876,22 +889,7 @@ def setitem(self, indexer, value, mgr=None): | |
values = transf(values) | ||
|
||
# length checking | ||
# boolean with truth values == len of the value is ok too | ||
if isinstance(indexer, (np.ndarray, list)): | ||
if is_list_like(value) and len(indexer) != len(value): | ||
if not (isinstance(indexer, np.ndarray) and | ||
indexer.dtype == np.bool_ and | ||
len(indexer[indexer]) == len(value)): | ||
raise ValueError("cannot set using a list-like indexer " | ||
"with a different length than the value") | ||
|
||
# slice | ||
elif isinstance(indexer, slice): | ||
|
||
if is_list_like(value) and len(values): | ||
if len(value) != length_of_indexer(indexer, values): | ||
raise ValueError("cannot set using a slice indexer with a " | ||
"different length than the value") | ||
check_setitem_lengths(indexer, value, values) | ||
|
||
def _is_scalar_indexer(indexer): | ||
# return True if we are all scalar indexers | ||
|
@@ -1900,6 +1898,37 @@ def is_view(self): | |
"""Extension arrays are never treated as views.""" | ||
return False | ||
|
||
def setitem(self, indexer, value, mgr=None):
    """Assign ``value`` into this block's values in place and return the block.

    Unlike ``Block.setitem``, this never changes the dtype of the
    block: the assignment is delegated directly to ``self.values``.

    Parameters
    ----------
    indexer : tuple, list-like, array-like, slice
        The subset of self.values to set. When a tuple is given, only
        its first element is used.
    value : object
        The value being set
    mgr : optional
        Not used by this implementation.

    Returns
    -------
    Block
        ``self``, after the in-place assignment.

    Notes
    -----
    `indexer` is a direct slice/positional indexer. `value` must
    be a compatible shape; the lengths are validated by
    ``check_setitem_lengths`` before assigning.
    """
    # The block is always 1-D, so a tuple indexer collapses to its
    # first component.
    key = indexer[0] if isinstance(indexer, tuple) else indexer

    check_setitem_lengths(key, value, self.values)
    self.values[key] = value
    return self
|
||
def get_values(self, dtype=None): | ||
# ExtensionArrays must be iterable, so this works. | ||
values = np.asarray(self.values) | ||
|
@@ -3519,7 +3548,8 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, | |
# with a .values attribute. | ||
aligned_args = dict((k, kwargs[k]) | ||
for k in align_keys | ||
if hasattr(kwargs[k], 'values')) | ||
if hasattr(kwargs[k], 'values') and | ||
not isinstance(kwargs[k], ABCExtensionArray)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If an ExtensionArray chooses to store its data as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we make a test for this? (e.g. call the underlying data There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DecimalArray calls its underlying data There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is pretty special casey here. shouldn't this check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure. I don't really know what could be in kwargs. You think it's only ever Index or Series? Or could it be a dataframe or block? |
||
|
||
for b in self.blocks: | ||
if filter is not None: | ||
|
@@ -5220,7 +5250,7 @@ def _safe_reshape(arr, new_shape): | |
If possible, reshape `arr` to have shape `new_shape`, | ||
with a couple of exceptions (see gh-13012): | ||
|
||
1) If `arr` is a Categorical or Index, `arr` will be | ||
1) If `arr` is an ExtensionArray or Index, `arr` will be | ||
returned as is. | ||
2) If `arr` is a Series, the `_values` attribute will | ||
be reshaped and returned. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import operator | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import pandas as pd | ||
import pandas.util.testing as tm | ||
from .base import BaseExtensionTests | ||
|
||
|
||
class BaseSetitemTests(BaseExtensionTests):
    """Shared ``__setitem__`` tests for extension-array backed pandas objects.

    Every test receives a ``data`` fixture, which is wrapped in a
    ``Series`` or ``DataFrame`` before being assigned into. The tests
    index ``data`` up to position 10, so the fixture must hold at
    least 11 elements.
    """

    def test_setitem_scalar_series(self, data):
        """Setting one position of a Series stores the scalar."""
        arr = pd.Series(data)
        arr[0] = data[1]
        assert arr[0] == data[1]

    def test_setitem_sequence(self, data):
        """Setting two positions from a list swaps the first two values."""
        arr = pd.Series(data)
        original = data.copy()

        arr[[0, 1]] = [data[1], data[0]]
        assert arr[0] == original[1]
        assert arr[1] == original[0]

    @pytest.mark.parametrize('as_array', [True, False])
    def test_setitem_sequence_mismatched_length_raises(self, data, as_array):
        """A value shorter than the list or slice indexer raises ValueError."""
        ser = pd.Series(data)
        value = [data[0]]
        if as_array:
            value = type(data)(value)

        xpr = 'cannot set using a {} indexer with a different length'
        with tm.assert_raises_regex(ValueError, xpr.format('list-like')):
            ser[[0, 1]] = value

        with tm.assert_raises_regex(ValueError, xpr.format('slice')):
            ser[slice(3)] = value

    def test_setitem_empty_indxer(self, data):
        """Assigning an empty list through an empty indexer is a no-op."""
        ser = pd.Series(data)
        original = ser.copy()
        ser[[]] = []
        self.assert_series_equal(ser, original)

    def test_setitem_sequence_broadcasts(self, data):
        """A scalar assigned through a list indexer fills every position."""
        arr = pd.Series(data)

        arr[[0, 1]] = data[2]
        assert arr[0] == data[2]
        assert arr[1] == data[2]

    @pytest.mark.parametrize('setter', ['loc', 'iloc'])
    def test_setitem_scalar(self, data, setter):
        """Scalar assignment works through both ``loc`` and ``iloc``."""
        arr = pd.Series(data)
        setter = getattr(arr, setter)
        operator.setitem(setter, 0, data[1])
        assert arr[0] == data[1]

    def test_setitem_loc_scalar_mixed(self, data):
        """``loc`` scalar set in a frame that also has a numpy column."""
        df = pd.DataFrame({"A": np.arange(len(data)), "B": data})
        df.loc[0, 'B'] = data[1]
        assert df.loc[0, 'B'] == data[1]

    def test_setitem_loc_scalar_single(self, data):
        """``loc`` scalar set in a frame with a single extension column."""
        df = pd.DataFrame({"B": data})
        df.loc[10, 'B'] = data[1]
        assert df.loc[10, 'B'] == data[1]

    def test_setitem_loc_scalar_multiple_homogoneous(self, data):
        """``loc`` scalar set in a frame with two extension columns."""
        df = pd.DataFrame({"A": data, "B": data})
        df.loc[10, 'B'] = data[1]
        assert df.loc[10, 'B'] == data[1]

    def test_setitem_iloc_scalar_mixed(self, data):
        """``iloc`` scalar set in a frame that also has a numpy column."""
        df = pd.DataFrame({"A": np.arange(len(data)), "B": data})
        df.iloc[0, 1] = data[1]
        assert df.loc[0, 'B'] == data[1]

    def test_setitem_iloc_scalar_single(self, data):
        """``iloc`` scalar set in a frame with a single extension column."""
        df = pd.DataFrame({"B": data})
        df.iloc[10, 0] = data[1]
        assert df.loc[10, 'B'] == data[1]

    def test_setitem_iloc_scalar_multiple_homogoneous(self, data):
        """``iloc`` scalar set in a frame with two extension columns."""
        df = pd.DataFrame({"A": data, "B": data})
        df.iloc[10, 1] = data[1]
        assert df.loc[10, 'B'] == data[1]

    @pytest.mark.parametrize('as_callable', [True, False])
    @pytest.mark.parametrize('setter', ['loc', None])
    def test_setitem_mask_aligned(self, data, as_callable, setter):
        """Boolean-mask assignment with a value whose length matches the mask.

        The mask is passed either directly or wrapped in a callable, and
        the assignment goes through ``ser.loc`` or ``Series.__setitem__``
        depending on ``setter``.
        """
        ser = pd.Series(data)
        mask = np.zeros(len(data), dtype=bool)
        mask[:2] = True

        if as_callable:
            mask2 = lambda x: mask
        else:
            mask2 = mask

        if setter:
            # loc
            target = getattr(ser, setter)
        else:
            # Series.__setitem__
            target = ser

        # This must be the only assignment: repeating it through
        # ``ser[mask2] = ...`` would hide a failure of the parametrized
        # setter path, because the asserts below would still pass.
        operator.setitem(target, mask2, data[5:7])

        assert ser[0] == data[5]
        assert ser[1] == data[6]

    @pytest.mark.parametrize('setter', ['loc', None])
    def test_setitem_mask_broadcast(self, data, setter):
        """A scalar assigned through a boolean mask fills the True slots."""
        ser = pd.Series(data)
        mask = np.zeros(len(data), dtype=bool)
        mask[:2] = True

        if setter:   # loc
            target = getattr(ser, setter)
        else:  # __setitem__
            target = ser

        operator.setitem(target, mask, data[10])
        assert ser[0] == data[10]
        assert ser[1] == data[10]

    def test_setitem_expand_columns(self, data):
        """Adding a new column to a frame that holds an extension column."""
        df = pd.DataFrame({"A": data})
        result = df.copy()
        result['B'] = 1
        expected = pd.DataFrame({"A": data, "B": [1] * len(data)})
        self.assert_frame_equal(result, expected)

        result = df.copy()
        result.loc[:, 'B'] = 1
        self.assert_frame_equal(result, expected)

        # overwrite with new type (replaces the existing int column,
        # as suggested in review)
        result['B'] = data
        expected = pd.DataFrame({"A": data, "B": data})
        self.assert_frame_equal(result, expected)

    def test_setitem_expand_with_extension(self, data):
        """Adding an extension column to a frame that holds an int column."""
        df = pd.DataFrame({"A": [1] * len(data)})
        result = df.copy()
        result['B'] = data
        expected = pd.DataFrame({"A": [1] * len(data), "B": data})
        self.assert_frame_equal(result, expected)

        result = df.copy()
        result.loc[:, 'B'] = data
        self.assert_frame_equal(result, expected)

    def test_setitem_frame_invalid_length(self, data):
        """Assigning a too-short extension array as a column raises."""
        df = pd.DataFrame({"A": [1] * len(data)})
        xpr = "Length of values does not match length of index"
        with tm.assert_raises_regex(ValueError, xpr):
            df['B'] = data[:5]

    @pytest.mark.xfail(reason="GH-20441: setitem on extension types.")
    def test_setitem_tuple_index(self, data):
        """Setting by a tuple label on a tuple-labelled Series (known xfail)."""
        s = pd.Series(data[:2], index=[(0, 0), (0, 1)])
        expected = pd.Series(data.take([1, 1]), index=s.index)
        s[(0, 1)] = data[1]
        self.assert_series_equal(s, expected)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was refactored out to pandas/util/_validators since I needed it in ExtensionArray.setitem.