From 8d43cbabfb60a31561f168276b7d177e374ed415 Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Tue, 2 May 2017 18:58:21 -0400 Subject: [PATCH] COMPAT: ensure proper extension dtype's don't pickle the cache --- pandas/core/dtypes/dtypes.py | 28 +++++++- pandas/tests/dtypes/test_dtypes.py | 109 +++++++++++++++++++++++++---- 2 files changed, 119 insertions(+), 18 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 59c23addd418e..561f1951a4151 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -24,6 +24,7 @@ class ExtensionDtype(object): isbuiltin = 0 isnative = 0 _metadata = [] + _cache = {} def __unicode__(self): return self.name @@ -71,6 +72,15 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __getstate__(self): + # pickle support; we don't want to pickle the cache + return {k: getattr(self, k, None) for k in self._metadata} + + @classmethod + def reset_cache(cls): + """ clear the cache """ + cls._cache = {} + @classmethod def is_dtype(cls, dtype): """ Return a boolean if the passed type is an actual dtype that @@ -110,6 +120,7 @@ class CategoricalDtype(ExtensionDtype): kind = 'O' str = '|O08' base = np.dtype('O') + _metadata = [] _cache = {} def __new__(cls): @@ -408,9 +419,15 @@ def __new__(cls, subtype=None): if isinstance(subtype, IntervalDtype): return subtype - elif subtype is None or (isinstance(subtype, compat.string_types) and - subtype == 'interval'): - subtype = None + elif subtype is None: + # we are called as an empty constructor + # generally for pickle compat + u = object.__new__(cls) + u.subtype = None + return u + elif (isinstance(subtype, compat.string_types) and + subtype == 'interval'): + subtype = '' else: if isinstance(subtype, compat.string_types): m = cls._match.search(subtype) @@ -423,6 +440,11 @@ def __new__(cls, subtype=None): except TypeError: raise ValueError("could not construct IntervalDtype") + if subtype is None: + u = object.__new__(cls) + u.subtype = None + return u + try: return cls._cache[str(subtype)] except KeyError: diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index da3120145fe38..fb20571213c15 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -23,6 +23,9 @@ class Base(object): + def setup_method(self, method): + self.dtype = self.create() + def test_hash(self): hash(self.dtype) @@ -37,14 +40,38 @@ def test_numpy_informed(self): assert not np.str_ == self.dtype def test_pickle(self): + # make sure our cache is NOT pickled + + # clear the cache + type(self.dtype).reset_cache() + assert not len(self.dtype._cache) + + # force back to the cache result = tm.round_trip_pickle(self.dtype) + assert not len(self.dtype._cache) assert result == self.dtype -class TestCategoricalDtype(Base, tm.TestCase): +class TestCategoricalDtype(Base): + + def create(self): + return CategoricalDtype() + + def test_pickle(self): + # make sure our cache is NOT pickled + + # clear the cache + type(self.dtype).reset_cache() + assert not len(self.dtype._cache) - def setUp(self): - self.dtype = CategoricalDtype() + # force back to the cache + result = tm.round_trip_pickle(self.dtype) + + # we are a singular object so we are added + # back to the cache upon unpickling + # this is to ensure object identity + assert len(self.dtype._cache) == 1 + assert result == self.dtype def test_hash_vs_equality(self): # make sure that we satisfy is semantics @@ -93,10 +120,10 @@ def test_basic(self): assert not is_categorical(1.0) -class TestDatetimeTZDtype(Base, tm.TestCase): +class TestDatetimeTZDtype(Base): - def setUp(self): - self.dtype = DatetimeTZDtype('ns', 'US/Eastern') + def create(self): + return DatetimeTZDtype('ns', 'US/Eastern') def test_hash_vs_equality(self): # make sure that we satisfy is semantics @@ -209,10 +236,24 @@ def test_empty(self): str(dt) -class TestPeriodDtype(Base, tm.TestCase): +class TestPeriodDtype(Base): - def setUp(self): - self.dtype = PeriodDtype('D') + def create(self): + return PeriodDtype('D') + + def test_hash_vs_equality(self): + # make sure that we satisfy is semantics + dtype = self.dtype + dtype2 = PeriodDtype('D') + dtype3 = PeriodDtype(dtype2) + assert dtype == dtype2 + assert dtype2 == dtype + assert dtype3 == dtype + assert dtype is dtype2 + assert dtype2 is dtype + assert dtype3 is dtype + assert hash(dtype) == hash(dtype2) + assert hash(dtype) == hash(dtype3) def test_construction(self): with pytest.raises(ValueError): @@ -338,11 +379,37 @@ def test_not_string(self): assert not is_string_dtype(PeriodDtype('D')) -class TestIntervalDtype(Base, tm.TestCase): +class TestIntervalDtype(Base): + + def create(self): + return IntervalDtype('int64') + + def test_hash_vs_equality(self): + # make sure that we satisfy is semantics + dtype = self.dtype + dtype2 = IntervalDtype('int64') + dtype3 = IntervalDtype(dtype2) + assert dtype == dtype2 + assert dtype2 == dtype + assert dtype3 == dtype + assert dtype is dtype2 + assert dtype2 is dtype + assert dtype3 is dtype + assert hash(dtype) == hash(dtype2) + assert hash(dtype) == hash(dtype3) - # TODO: placeholder - def setUp(self): - self.dtype = IntervalDtype('int64') + dtype1 = IntervalDtype('interval') + dtype2 = IntervalDtype(dtype1) + dtype3 = IntervalDtype('interval') + assert dtype2 == dtype1 + assert dtype2 == dtype2 + assert dtype2 == dtype3 + assert dtype2 is dtype1 + assert dtype2 is dtype2 + assert dtype2 is dtype3 + assert hash(dtype2) == hash(dtype1) + assert hash(dtype2) == hash(dtype2) + assert hash(dtype2) == hash(dtype3) def test_construction(self): with pytest.raises(ValueError): @@ -356,9 +423,9 @@ def test_construction(self): def test_construction_generic(self): # generic i = IntervalDtype('interval') - assert i.subtype is None + assert i.subtype == '' assert is_interval_dtype(i) - assert str(i) == 'interval' + assert str(i) == 'interval[]' i = IntervalDtype() assert i.subtype is None @@ -445,3 +512,15 @@ def test_basic_dtype(self): assert not is_interval_dtype(np.object_) assert not is_interval_dtype(np.int64) assert not is_interval_dtype(np.float64) + + def test_caching(self): + IntervalDtype.reset_cache() + dtype = IntervalDtype("int64") + assert len(IntervalDtype._cache) == 1 + + IntervalDtype("interval") + assert len(IntervalDtype._cache) == 2 + + IntervalDtype.reset_cache() + tm.round_trip_pickle(dtype) + assert len(IntervalDtype._cache) == 0