Skip to content

PERF: Cache hashing of categories #40193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ Performance improvements
- Performance improvement in :class:`Styler` where render times are more than 50% reduced (:issue:`39972`, :issue:`39952`)
- Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`)
- Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`)
- Performance improvement for concatenation of data with type :class:`CategoricalDtype` (:issue:`40193`)

.. ---------------------------------------------------------------------------

Expand Down Expand Up @@ -795,6 +796,7 @@ ExtensionArray

- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`)
- Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`)
- Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` were improperly cached (:issue:`40329`)
-

Other
Expand Down
30 changes: 17 additions & 13 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytz

from pandas._libs.interval import Interval
from pandas._libs.properties import cache_readonly
from pandas._libs.tslibs import (
BaseOffset,
NaT,
Expand Down Expand Up @@ -81,7 +82,7 @@ class PandasExtensionDtype(ExtensionDtype):
base: DtypeObj | None = None
isbuiltin = 0
isnative = 0
_cache: dict[str_type, PandasExtensionDtype] = {}
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}

def __str__(self) -> str_type:
"""
Expand All @@ -105,7 +106,7 @@ def __getstate__(self) -> dict[str_type, Any]:
@classmethod
def reset_cache(cls) -> None:
""" clear the cache """
cls._cache = {}
cls._cache_dtypes = {}


class CategoricalDtypeType(type):
Expand Down Expand Up @@ -177,7 +178,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
str = "|O08"
base = np.dtype("O")
_metadata = ("categories", "ordered")
_cache: dict[str_type, PandasExtensionDtype] = {}
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}

def __init__(self, categories=None, ordered: Ordered = False):
self._finalize(categories, ordered, fastpath=False)
Expand Down Expand Up @@ -355,7 +356,7 @@ def __hash__(self) -> int:
else:
return -2
# We *do* want to include the real self.ordered here
return int(self._hash_categories(self.categories, self.ordered))
return int(self._hash_categories)

def __eq__(self, other: Any) -> bool:
"""
Expand Down Expand Up @@ -429,14 +430,17 @@ def __repr__(self) -> str_type:
data = data.rstrip(", ")
return f"CategoricalDtype(categories={data}, ordered={self.ordered})"

@staticmethod
def _hash_categories(categories, ordered: Ordered = True) -> int:
@cache_readonly
def _hash_categories(self) -> int:
from pandas.core.util.hashing import (
combine_hash_arrays,
hash_array,
hash_tuples,
)

categories = self.categories
ordered = self.ordered

if len(categories) and isinstance(categories[0], tuple):
# assumes if any individual category is a tuple, then all are. ATM
# I don't really want to support just some of the categories being
Expand Down Expand Up @@ -671,7 +675,7 @@ class DatetimeTZDtype(PandasExtensionDtype):
na_value = NaT
_metadata = ("unit", "tz")
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
_cache: dict[str_type, PandasExtensionDtype] = {}
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}

def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None):
if isinstance(unit, DatetimeTZDtype):
Expand Down Expand Up @@ -837,7 +841,7 @@ class PeriodDtype(dtypes.PeriodDtypeBase, PandasExtensionDtype):
num = 102
_metadata = ("freq",)
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
_cache: dict[str_type, PandasExtensionDtype] = {}
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}

def __new__(cls, freq=None):
"""
Expand All @@ -859,12 +863,12 @@ def __new__(cls, freq=None):
freq = cls._parse_dtype_strict(freq)

try:
return cls._cache[freq.freqstr]
return cls._cache_dtypes[freq.freqstr]
except KeyError:
dtype_code = freq._period_dtype_code
u = dtypes.PeriodDtypeBase.__new__(cls, dtype_code)
u._freq = freq
cls._cache[freq.freqstr] = u
cls._cache_dtypes[freq.freqstr] = u
return u

def __reduce__(self):
Expand Down Expand Up @@ -1042,7 +1046,7 @@ class IntervalDtype(PandasExtensionDtype):
_match = re.compile(
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
)
_cache: dict[str_type, PandasExtensionDtype] = {}
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}

def __new__(cls, subtype=None, closed: str_type | None = None):
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -1099,12 +1103,12 @@ def __new__(cls, subtype=None, closed: str_type | None = None):

key = str(subtype) + str(closed)
try:
return cls._cache[key]
return cls._cache_dtypes[key]
except KeyError:
u = object.__new__(cls)
u._subtype = subtype
u._closed = closed
cls._cache[key] = u
cls._cache_dtypes[key] = u
return u

@property
Expand Down
10 changes: 5 additions & 5 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ def test_pickle(self, dtype):

# clear the cache
type(dtype).reset_cache()
assert not len(dtype._cache)
assert not len(dtype._cache_dtypes)

# force back to the cache
result = tm.round_trip_pickle(dtype)
if not isinstance(dtype, PeriodDtype):
# Because PeriodDtype has a cython class as a base class,
# it has different pickle semantics, and its cache is re-populated
# on un-pickling.
assert not len(dtype._cache)
assert not len(dtype._cache_dtypes)
assert result == dtype


Expand Down Expand Up @@ -791,14 +791,14 @@ def test_basic_dtype(self):
def test_caching(self):
IntervalDtype.reset_cache()
dtype = IntervalDtype("int64", "right")
assert len(IntervalDtype._cache) == 1
assert len(IntervalDtype._cache_dtypes) == 1

IntervalDtype("interval")
assert len(IntervalDtype._cache) == 2
assert len(IntervalDtype._cache_dtypes) == 2

IntervalDtype.reset_cache()
tm.round_trip_pickle(dtype)
assert len(IntervalDtype._cache) == 0
assert len(IntervalDtype._cache_dtypes) == 0

def test_not_string(self):
# GH30568: though IntervalDtype has object kind, it cannot be string
Expand Down