From f05b9ae3af1f38d997f8181dba69bc3f5635222f Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 4 Aug 2023 15:04:31 -0700 Subject: [PATCH] ENH: Dtype._is_immutable --- pandas/core/dtypes/base.py | 10 ++++++++++ pandas/core/dtypes/dtypes.py | 2 ++ pandas/core/internals/managers.py | 3 +-- pandas/tests/extension/base/interface.py | 7 +++++++ pandas/tests/extension/base/reshaping.py | 6 ++++++ pandas/tests/extension/base/setitem.py | 18 ++++++++++++++++++ pandas/tests/extension/test_sparse.py | 19 +++---------------- 7 files changed, 47 insertions(+), 18 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index c6c162001d147..bc776434b2e6e 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -396,6 +396,16 @@ def _can_hold_na(self) -> bool: """ return True + @property + def _is_immutable(self) -> bool: + """ + Can arrays with this dtype be modified with __setitem__? If not, return + True. + + Immutable arrays are expected to raise TypeError on __setitem__ calls. + """ + return False + class StorageExtensionDtype(ExtensionDtype): """ExtensionDtype that may be backed by more than one implementation.""" diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 0d3e955696d81..53f0fb2843653 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1608,6 +1608,8 @@ class SparseDtype(ExtensionDtype): 0.3333333333333333 """ + _is_immutable = True + # We include `_is_na_fill_value` in the metadata to avoid hash collisions # between SparseDtype(float, 0.0) and SparseDtype(float, nan). # Without is_na_fill_value in the comparison, those would be equal since diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index 05577fb971061..1533dc77321cb 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -39,7 +39,6 @@ from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, ExtensionDtype, - SparseDtype, ) from pandas.core.dtypes.generic import ( ABCDataFrame, @@ -943,7 +942,7 @@ def fast_xs(self, loc: int) -> SingleBlockManager: n = len(self) # GH#46406 - immutable_ea = isinstance(dtype, SparseDtype) + immutable_ea = isinstance(dtype, ExtensionDtype) and dtype._is_immutable if isinstance(dtype, ExtensionDtype) and not immutable_ea: cls = dtype.construct_array_type() diff --git a/pandas/tests/extension/base/interface.py b/pandas/tests/extension/base/interface.py index 3e8a754c8c527..92d50e5bd9a66 100644 --- a/pandas/tests/extension/base/interface.py +++ b/pandas/tests/extension/base/interface.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from pandas.core.dtypes.common import is_extension_array_dtype from pandas.core.dtypes.dtypes import ExtensionDtype @@ -102,6 +103,9 @@ def test_copy(self, data): assert data[0] != data[1] result = data.copy() + if data.dtype._is_immutable: + pytest.skip("test_copy assumes mutability") + data[1] = data[0] assert result[1] != result[0] @@ -114,6 +118,9 @@ def test_view(self, data): assert result is not data assert type(result) == type(data) + if data.dtype._is_immutable: + pytest.skip("test_view assumes mutability") + result[1] = result[0] assert data[1] == data[0] diff --git a/pandas/tests/extension/base/reshaping.py b/pandas/tests/extension/base/reshaping.py index 3f89ef5395006..ea618ead7c84d 100644 --- a/pandas/tests/extension/base/reshaping.py +++ b/pandas/tests/extension/base/reshaping.py @@ -334,6 +334,9 @@ def test_ravel(self, data): result = data.ravel() assert type(result) == type(data) + if data.dtype._is_immutable: + pytest.skip("test_ravel assumes mutability") + # Check that we have a view, not a copy result[0] = result[1] assert data[0] == data[1] @@ -348,6 +351,9 @@ def test_transpose(self, data): # If we ever _did_ support 2D, shape should be reversed assert result.shape == data.shape[::-1] + if data.dtype._is_immutable: + pytest.skip("test_transpose assumes mutability") + # Check that we have a view, not a copy result[0] = result[1] assert data[0] == data[1] diff --git a/pandas/tests/extension/base/setitem.py b/pandas/tests/extension/base/setitem.py index 1085ada920ccc..66842dbc18145 100644 --- a/pandas/tests/extension/base/setitem.py +++ b/pandas/tests/extension/base/setitem.py @@ -36,6 +36,24 @@ def full_indexer(self, request): """ return request.param + @pytest.fixture(autouse=True) + def skip_if_immutable(self, dtype, request): + if dtype._is_immutable: + node = request.node + if node.name.split("[")[0] == "test_is_immutable": + # This fixture is auto-used, but we want to not-skip + # test_is_immutable. + return + pytest.skip("__setitem__ test not applicable with immutable dtype") + + def test_is_immutable(self, data): + if data.dtype._is_immutable: + with pytest.raises(TypeError): + data[0] = data[0] + else: + data[0] = data[1] + assert data[0] == data[1] + def test_setitem_scalar_series(self, data, box_in_series): if box_in_series: data = pd.Series(data) diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index a39133c784380..90997160f2b08 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -108,10 +108,6 @@ def _check_unsupported(self, data): if data.dtype == SparseDtype(int, 0): pytest.skip("Can't store nan in int array.") - @pytest.mark.xfail(reason="SparseArray does not support setitem") - def test_ravel(self, data): - super().test_ravel(data) - class TestDtype(BaseSparseTests, base.BaseDtypeTests): def test_array_type_with_arg(self, data, dtype): @@ -119,13 +115,7 @@ def test_array_type_with_arg(self, data, dtype): class TestInterface(BaseSparseTests, base.BaseInterfaceTests): - def test_copy(self, data): - # __setitem__ does not work, so we only have a smoke-test - data.copy() - - def test_view(self, data): - # __setitem__ does not work, so we only have a smoke-test - data.view() + pass class TestConstructors(BaseSparseTests, base.BaseConstructorsTests): @@ -185,10 +175,6 @@ def test_merge(self, data, na_value): self._check_unsupported(data) super().test_merge(data, na_value) - @pytest.mark.xfail(reason="SparseArray does not support setitem") - def test_transpose(self, data): - super().test_transpose(data) - class TestGetitem(BaseSparseTests, base.BaseGetitemTests): def test_get(self, data): @@ -204,7 +190,8 @@ def test_reindex(self, data, na_value): super().test_reindex(data, na_value) -# Skipping TestSetitem, since we don't implement it. +class TestSetitem(BaseSparseTests, base.BaseSetitemTests): + pass class TestIndex(base.BaseIndexTests):