|
| 1 | +import pytest |
| 2 | + |
| 3 | +from .base import BaseExtensionTests |
| 4 | + |
| 5 | + |
| 6 | +class BaseSetitemTests(BaseExtensionTests): |
| 7 | + """Tests for ExtensionArray.__setitem__""" |
| 8 | + |
| 9 | + def test_set_scalar(self, data): |
| 10 | + expected = data.take([1, 1]) |
| 11 | + subset = data[:2].copy() |
| 12 | + |
| 13 | + subset[0] = data[1] |
| 14 | + self.assert_extension_array_equal(subset, expected) |
| 15 | + |
| 16 | + def test_set_mask_scalar(self, data): |
| 17 | + expected = data.take([1, 1, 2, 1]) |
| 18 | + subset = data[:4].copy() |
| 19 | + |
| 20 | + subset[[True, True, False, True]] = data[1] |
| 21 | + self.assert_extension_array_equal(subset, expected) |
| 22 | + |
| 23 | + @pytest.mark.parametrize('key', [ |
| 24 | + [False, True, True, True], |
| 25 | + [1, 2, 3], |
| 26 | + ], ids=['mask', 'fancy']) |
| 27 | + def test_set_array(self, key, data): |
| 28 | + expected = data.take([0, 2, 2, 1]) |
| 29 | + value = data.take([2, 2, 1]) |
| 30 | + subset = data[:4].copy() |
| 31 | + |
| 32 | + subset[key] = value |
| 33 | + self.assert_extension_array_equal(subset, expected) |
| 34 | + |
| 35 | + def test_bad_mask_bad_length_raise(self, data): |
| 36 | + value = data[0] |
| 37 | + with pytest.raises(IndexError): |
| 38 | + data[[True, False]] = value |
0 commit comments