diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 3a0e1b7568c91..d3e065f0fc893 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -795,6 +795,7 @@ ExtensionArray - Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`) - Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`) +- Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`) - Other diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 2785874878c96..ddd6a76cb83dc 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -81,7 +81,7 @@ class PandasExtensionDtype(ExtensionDtype): base: DtypeObj | None = None isbuiltin = 0 isnative = 0 - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __str__(self) -> str_type: """ @@ -105,7 +105,7 @@ def __getstate__(self) -> dict[str_type, Any]: @classmethod def reset_cache(cls) -> None: """ clear the cache """ - cls._cache = {} + cls._cache_dtypes = {} class CategoricalDtypeType(type): @@ -177,7 +177,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): str = "|O08" base = np.dtype("O") _metadata = ("categories", "ordered") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __init__(self, categories=None, ordered: Ordered = False): self._finalize(categories, ordered, fastpath=False) @@ -671,7 +671,7 @@ class DatetimeTZDtype(PandasExtensionDtype): na_value = NaT _metadata = ("unit", "tz") _match = re.compile(r"(datetime64|M8)\[(?P.+), (?P.+)\]") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None): if isinstance(unit, DatetimeTZDtype): @@ -837,7 +837,7 @@ class PeriodDtype(dtypes.PeriodDtypeBase, PandasExtensionDtype): num = 102 _metadata = ("freq",) _match = re.compile(r"(P|p)eriod\[(?P.+)\]") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __new__(cls, freq=None): """ @@ -859,12 +859,12 @@ def __new__(cls, freq=None): freq = cls._parse_dtype_strict(freq) try: - return cls._cache[freq.freqstr] + return cls._cache_dtypes[freq.freqstr] except KeyError: dtype_code = freq._period_dtype_code u = dtypes.PeriodDtypeBase.__new__(cls, dtype_code) u._freq = freq - cls._cache[freq.freqstr] = u + cls._cache_dtypes[freq.freqstr] = u return u def __reduce__(self): @@ -1042,7 +1042,7 @@ class IntervalDtype(PandasExtensionDtype): _match = re.compile( r"(I|i)nterval\[(?P[^,]+)(, (?P(right|left|both|neither)))?\]" ) - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __new__(cls, subtype=None, closed: str_type | None = None): from pandas.core.dtypes.common import ( @@ -1099,12 +1099,12 @@ def __new__(cls, subtype=None, closed: str_type | None = None): key = str(subtype) + str(closed) try: - return cls._cache[key] + return cls._cache_dtypes[key] except KeyError: u = object.__new__(cls) u._subtype = subtype u._closed = closed - cls._cache[key] = u + cls._cache_dtypes[key] = u return u @property diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index 51a7969162abf..abb29ce66fd34 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -66,7 +66,7 @@ def test_pickle(self, dtype): # clear the cache type(dtype).reset_cache() - assert not len(dtype._cache) + assert not len(dtype._cache_dtypes) # force back to the cache result = tm.round_trip_pickle(dtype) @@ -74,7 +74,7 @@ def test_pickle(self, dtype): # Because PeriodDtype has a cython class as a base class, # it has different pickle semantics, and its cache is re-populated # on un-pickling. - assert not len(dtype._cache) + assert not len(dtype._cache_dtypes) assert result == dtype @@ -791,14 +791,14 @@ def test_basic_dtype(self): def test_caching(self): IntervalDtype.reset_cache() dtype = IntervalDtype("int64", "right") - assert len(IntervalDtype._cache) == 1 + assert len(IntervalDtype._cache_dtypes) == 1 IntervalDtype("interval") - assert len(IntervalDtype._cache) == 2 + assert len(IntervalDtype._cache_dtypes) == 2 IntervalDtype.reset_cache() tm.round_trip_pickle(dtype) - assert len(IntervalDtype._cache) == 0 + assert len(IntervalDtype._cache_dtypes) == 0 def test_not_string(self): # GH30568: though IntervalDtype has object kind, it cannot be string