Skip to content

Commit de96a61

Browse files
committed
Added IntervalArray.__setitem__
1 parent 11d97db commit de96a61

File tree

6 files changed

+71
-0
lines changed

6 files changed

+71
-0
lines changed

pandas/core/arrays/interval.py

+20
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,26 @@ def __getitem__(self, value):
456456

457457
return self._shallow_copy(left, right)
458458

459+
def __setitem__(self, key, value):
460+
if not (is_interval_dtype(value) or isinstance(value, Interval)):
461+
msg = "'value' should be an interval type, got {} instead."
462+
raise TypeError(msg.format(type(value)))
463+
464+
if value.closed != self.closed:
465+
msg = "'value.closed' ({}) does not match {}."
466+
raise ValueError(value.closed, self.closed)
467+
468+
# Need to ensure that left and right are updated atomically, so we're
469+
# forced to copy, update the copy, and swap in the new values.
470+
left = self.left.copy(deep=True)
471+
right = self.right.copy(deep=True)
472+
473+
left.values[key] = value.left
474+
right.values[key] = value.right
475+
476+
self._left = left
477+
self._right = right
478+
459479
def fillna(self, value=None, method=None, limit=None):
460480
if method is not None:
461481
raise TypeError('Filling by method is not supported for '

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ class TestMyDtype(BaseDtypeTests):
4949
from .methods import BaseMethodsTests # noqa
5050
from .missing import BaseMissingTests # noqa
5151
from .reshaping import BaseReshapingTests # noqa
52+
from .setitem import BaseSetitemTests # noqa
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
3+
from .base import BaseExtensionTests
4+
5+
6+
class BaseSetitemTests(BaseExtensionTests):
7+
"""Tests for ExtensionArray.__setitem__"""
8+
9+
def test_set_scalar(self, data):
10+
expected = data.take([1, 1])
11+
subset = data[:2].copy()
12+
13+
subset[0] = data[1]
14+
self.assert_extension_array_equal(subset, expected)
15+
16+
def test_set_mask_scalar(self, data):
17+
expected = data.take([1, 1, 2, 1])
18+
subset = data[:4].copy()
19+
20+
subset[[True, True, False, True]] = data[1]
21+
self.assert_extension_array_equal(subset, expected)
22+
23+
@pytest.mark.parametrize('key', [
24+
[False, True, True, True],
25+
[1, 2, 3],
26+
], ids=['mask', 'fancy'])
27+
def test_set_array(self, key, data):
28+
expected = data.take([0, 2, 2, 1])
29+
value = data.take([2, 2, 1])
30+
subset = data[:4].copy()
31+
32+
subset[key] = value
33+
self.assert_extension_array_equal(subset, expected)
34+
35+
def test_bad_mask_bad_length_raise(self, data):
36+
value = data[0]
37+
with pytest.raises(IndexError):
38+
data[[True, False]] = value

pandas/tests/extension/category/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,7 @@ def test_value_counts(self, all_data, dropna):
106106

107107
class TestCasting(base.BaseCastingTests):
108108
pass
109+
110+
111+
class TestSetitem(base.BaseSetitemTests):
112+
pass

pandas/tests/extension/decimal/test_decimal.py

+4
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
131131
pass
132132

133133

134+
class TestSetitem(BaseDecimal, base.BaseSetitemTests):
135+
pass
136+
137+
134138
def test_series_constructor_coerce_data_to_extension_dtype_raises():
135139
xpr = ("Cannot cast data to extension dtype 'decimal'. Pass the "
136140
"extension array directly.")

pandas/tests/extension/test_interval.py

+4
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ class TestReshaping(BaseInterval, base.BaseReshapingTests):
118118
pass
119119

120120

121+
class TestSetitem(BaseInterval, base.BaseSetitemTests):
122+
pass
123+
124+
121125
def test_repr_matches():
122126
idx = IntervalIndex.from_breaks([1, 2, 3])
123127
a = repr(idx)

0 commit comments

Comments
 (0)