Skip to content

Commit 67e1c09

Browse files
authored
PERF: Cache hashing of categories (#40193)
1 parent 9caf503 commit 67e1c09

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

doc/source/whatsnew/v1.3.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ Performance improvements
584584
- Performance improvement in :class:`Styler` where render times are more than 50% reduced (:issue:`39972` :issue:`39952`)
585585
- Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`)
586586
- Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`)
587+
- Performance improvement for concatenation of data with type :class:`CategoricalDtype` (:issue:`40193`)
587588

588589
.. ---------------------------------------------------------------------------
589590
@@ -813,6 +814,7 @@ ExtensionArray
813814

814815
- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`)
815816
- Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`)
817+
- Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`)
816818
-
817819

818820
Other

pandas/core/dtypes/dtypes.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytz
1616

1717
from pandas._libs.interval import Interval
18+
from pandas._libs.properties import cache_readonly
1819
from pandas._libs.tslibs import (
1920
BaseOffset,
2021
NaT,
@@ -81,7 +82,7 @@ class PandasExtensionDtype(ExtensionDtype):
8182
base: DtypeObj | None = None
8283
isbuiltin = 0
8384
isnative = 0
84-
_cache: dict[str_type, PandasExtensionDtype] = {}
85+
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
8586

8687
def __str__(self) -> str_type:
8788
"""
@@ -105,7 +106,7 @@ def __getstate__(self) -> dict[str_type, Any]:
105106
@classmethod
106107
def reset_cache(cls) -> None:
107108
""" clear the cache """
108-
cls._cache = {}
109+
cls._cache_dtypes = {}
109110

110111

111112
class CategoricalDtypeType(type):
@@ -177,7 +178,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
177178
str = "|O08"
178179
base = np.dtype("O")
179180
_metadata = ("categories", "ordered")
180-
_cache: dict[str_type, PandasExtensionDtype] = {}
181+
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
181182

182183
def __init__(self, categories=None, ordered: Ordered = False):
183184
self._finalize(categories, ordered, fastpath=False)
@@ -355,7 +356,7 @@ def __hash__(self) -> int:
355356
else:
356357
return -2
357358
# We *do* want to include the real self.ordered here
358-
return int(self._hash_categories(self.categories, self.ordered))
359+
return int(self._hash_categories)
359360

360361
def __eq__(self, other: Any) -> bool:
361362
"""
@@ -429,14 +430,17 @@ def __repr__(self) -> str_type:
429430
data = data.rstrip(", ")
430431
return f"CategoricalDtype(categories={data}, ordered={self.ordered})"
431432

432-
@staticmethod
433-
def _hash_categories(categories, ordered: Ordered = True) -> int:
433+
@cache_readonly
434+
def _hash_categories(self) -> int:
434435
from pandas.core.util.hashing import (
435436
combine_hash_arrays,
436437
hash_array,
437438
hash_tuples,
438439
)
439440

441+
categories = self.categories
442+
ordered = self.ordered
443+
440444
if len(categories) and isinstance(categories[0], tuple):
441445
# assumes if any individual category is a tuple, then all our. ATM
442446
# I don't really want to support just some of the categories being
@@ -671,7 +675,7 @@ class DatetimeTZDtype(PandasExtensionDtype):
671675
na_value = NaT
672676
_metadata = ("unit", "tz")
673677
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
674-
_cache: dict[str_type, PandasExtensionDtype] = {}
678+
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
675679

676680
def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None):
677681
if isinstance(unit, DatetimeTZDtype):
@@ -837,7 +841,7 @@ class PeriodDtype(dtypes.PeriodDtypeBase, PandasExtensionDtype):
837841
num = 102
838842
_metadata = ("freq",)
839843
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
840-
_cache: dict[str_type, PandasExtensionDtype] = {}
844+
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
841845

842846
def __new__(cls, freq=None):
843847
"""
@@ -859,12 +863,12 @@ def __new__(cls, freq=None):
859863
freq = cls._parse_dtype_strict(freq)
860864

861865
try:
862-
return cls._cache[freq.freqstr]
866+
return cls._cache_dtypes[freq.freqstr]
863867
except KeyError:
864868
dtype_code = freq._period_dtype_code
865869
u = dtypes.PeriodDtypeBase.__new__(cls, dtype_code)
866870
u._freq = freq
867-
cls._cache[freq.freqstr] = u
871+
cls._cache_dtypes[freq.freqstr] = u
868872
return u
869873

870874
def __reduce__(self):
@@ -1042,7 +1046,7 @@ class IntervalDtype(PandasExtensionDtype):
10421046
_match = re.compile(
10431047
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
10441048
)
1045-
_cache: dict[str_type, PandasExtensionDtype] = {}
1049+
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
10461050

10471051
def __new__(cls, subtype=None, closed: str_type | None = None):
10481052
from pandas.core.dtypes.common import (
@@ -1099,12 +1103,12 @@ def __new__(cls, subtype=None, closed: str_type | None = None):
10991103

11001104
key = str(subtype) + str(closed)
11011105
try:
1102-
return cls._cache[key]
1106+
return cls._cache_dtypes[key]
11031107
except KeyError:
11041108
u = object.__new__(cls)
11051109
u._subtype = subtype
11061110
u._closed = closed
1107-
cls._cache[key] = u
1111+
cls._cache_dtypes[key] = u
11081112
return u
11091113

11101114
@property

pandas/tests/dtypes/test_dtypes.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def test_pickle(self, dtype):
6666

6767
# clear the cache
6868
type(dtype).reset_cache()
69-
assert not len(dtype._cache)
69+
assert not len(dtype._cache_dtypes)
7070

7171
# force back to the cache
7272
result = tm.round_trip_pickle(dtype)
7373
if not isinstance(dtype, PeriodDtype):
7474
# Because PeriodDtype has a cython class as a base class,
7575
# it has different pickle semantics, and its cache is re-populated
7676
# on un-pickling.
77-
assert not len(dtype._cache)
77+
assert not len(dtype._cache_dtypes)
7878
assert result == dtype
7979

8080

@@ -791,14 +791,14 @@ def test_basic_dtype(self):
791791
def test_caching(self):
792792
IntervalDtype.reset_cache()
793793
dtype = IntervalDtype("int64", "right")
794-
assert len(IntervalDtype._cache) == 1
794+
assert len(IntervalDtype._cache_dtypes) == 1
795795

796796
IntervalDtype("interval")
797-
assert len(IntervalDtype._cache) == 2
797+
assert len(IntervalDtype._cache_dtypes) == 2
798798

799799
IntervalDtype.reset_cache()
800800
tm.round_trip_pickle(dtype)
801-
assert len(IntervalDtype._cache) == 0
801+
assert len(IntervalDtype._cache_dtypes) == 0
802802

803803
def test_not_string(self):
804804
# GH30568: though IntervalDtype has object kind, it cannot be string

0 commit comments

Comments
 (0)