Skip to content

Commit 7dc8f9f

Browse files
committed
COMPAT: ensure proper extension dtypes don't pickle the cache
1 parent 39cc1d0 commit 7dc8f9f

File tree

2 files changed

+77
-9
lines changed

2 files changed

+77
-9
lines changed

pandas/core/dtypes/dtypes.py

+17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ExtensionDtype(object):
2424
isbuiltin = 0
2525
isnative = 0
2626
_metadata = []
27+
_cache = {}
2728

2829
def __unicode__(self):
2930
return self.name
@@ -71,6 +72,15 @@ def __eq__(self, other):
7172
def __ne__(self, other):
    """Return the logical negation of ``self.__eq__(other)``."""
    is_equal = self.__eq__(other)
    return not is_equal
7374

75+
def __getstate__(self):
    """Return the state used for pickling.

    Only the attributes named in ``self._metadata`` are included, so the
    cache is not pickled. An attribute missing on the instance is stored
    as ``None``.
    """
    state = {}
    for key in self._metadata:
        state[key] = getattr(self, key, None)
    return state
78+
79+
@classmethod
def reset_cache(cls):
    """Clear the cache by rebinding ``cls._cache`` to a new empty dict."""
    fresh = dict()
    cls._cache = fresh
83+
7484
@classmethod
7585
def is_dtype(cls, dtype):
7686
""" Return a boolean if the passed type is an actual dtype that
@@ -111,6 +121,7 @@ class CategoricalDtype(ExtensionDtype):
111121
str = '|O08'
112122
base = np.dtype('O')
113123
_cache = {}
124+
_attributes = ['name']
114125

115126
def __new__(cls):
116127

@@ -423,6 +434,12 @@ def __new__(cls, subtype=None):
423434
except TypeError:
424435
raise ValueError("could not construct IntervalDtype")
425436

437+
if subtype is None:
438+
# pickle compat
439+
u = object.__new__(cls)
440+
u.subtype = None
441+
return u
442+
426443
try:
427444
return cls._cache[str(subtype)]
428445
except KeyError:

pandas/tests/dtypes/test_dtypes.py

+60-9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
class Base(object):
2525

26+
def setUp(self):
    """Store the object returned by ``self.create()`` as ``self.dtype``."""
    created = self.create()
    self.dtype = created
28+
2629
def test_hash(self):
2730
hash(self.dtype)
2831

@@ -37,14 +40,38 @@ def test_numpy_informed(self):
3740
assert not np.str_ == self.dtype
3841

3942
def test_pickle(self):
    """Check that a pickle round trip leaves the dtype's cache empty."""
    dtype_cls = type(self.dtype)

    # start from an empty cache
    dtype_cls.reset_cache()
    assert len(self.dtype._cache) == 0

    # the round trip must not put anything back into the cache
    roundtripped = tm.round_trip_pickle(self.dtype)
    assert len(self.dtype._cache) == 0
    assert roundtripped == self.dtype
4253

4354

4455
class TestCategoricalDtype(Base, tm.TestCase):
4556

46-
def setUp(self):
47-
self.dtype = CategoricalDtype()
57+
def create(self):
    """Return a new ``CategoricalDtype()`` for use as the dtype under test."""
    dtype = CategoricalDtype()
    return dtype
59+
60+
def test_pickle(self):
    """Check the cache holds exactly one entry after a pickle round trip."""
    dtype_cls = type(self.dtype)

    # start from an empty cache
    dtype_cls.reset_cache()
    assert len(self.dtype._cache) == 0

    roundtripped = tm.round_trip_pickle(self.dtype)

    # this dtype is a singular object, so unpickling adds it back to the
    # cache; that is what preserves object identity
    assert len(self.dtype._cache) == 1
    assert roundtripped == self.dtype
4875

4976
def test_hash_vs_equality(self):
5077
# make sure that we satisfy is semantics
@@ -95,8 +122,8 @@ def test_basic(self):
95122

96123
class TestDatetimeTZDtype(Base, tm.TestCase):
97124

98-
def setUp(self):
99-
self.dtype = DatetimeTZDtype('ns', 'US/Eastern')
125+
def create(self):
    """Return ``DatetimeTZDtype('ns', 'US/Eastern')`` as the dtype under test."""
    dtype = DatetimeTZDtype('ns', 'US/Eastern')
    return dtype
100127

101128
def test_hash_vs_equality(self):
102129
# make sure that we satisfy is semantics
@@ -211,8 +238,8 @@ def test_empty(self):
211238

212239
class TestPeriodDtype(Base, tm.TestCase):
213240

214-
def setUp(self):
215-
self.dtype = PeriodDtype('D')
241+
def create(self):
    """Return ``PeriodDtype('D')`` as the dtype under test."""
    dtype = PeriodDtype('D')
    return dtype
216243

217244
def test_construction(self):
218245
with pytest.raises(ValueError):
@@ -340,9 +367,22 @@ def test_not_string(self):
340367

341368
class TestIntervalDtype(Base, tm.TestCase):
342369

343-
# TODO: placeholder
344-
def setUp(self):
345-
self.dtype = IntervalDtype('int64')
370+
def create(self):
    """Return ``IntervalDtype('int64')`` as the dtype under test."""
    dtype = IntervalDtype('int64')
    return dtype
372+
373+
def test_hash_vs_equality(self):
    """Check equality, identity and hash agreement between IntervalDtypes.

    Compares ``self.dtype`` with one built from ``'int64'`` and one built
    from an existing ``IntervalDtype`` instance.
    """
    first = self.dtype
    second = IntervalDtype('int64')
    third = IntervalDtype(second)

    # equality, checked in both directions
    assert first == second
    assert second == first
    assert third == first

    # "is" semantics: all three names refer to the same object
    assert first is second
    assert second is first
    assert third is first

    # equal objects must hash equally
    assert hash(first) == hash(second)
    assert hash(first) == hash(third)
346386

347387
def test_construction(self):
348388
with pytest.raises(ValueError):
@@ -445,3 +485,14 @@ def test_basic_dtype(self):
445485
assert not is_interval_dtype(np.object_)
446486
assert not is_interval_dtype(np.int64)
447487
assert not is_interval_dtype(np.float64)
488+
489+
def test_caching(self):
    """Check that ``IntervalDtype._cache`` keeps exactly one entry.

    The size is checked after the first construction, after constructing
    from ``"interval"``, and after a pickle round trip.
    """
    def cache_size():
        return len(IntervalDtype._cache)

    IntervalDtype.reset_cache()
    dtype = IntervalDtype("int64")
    assert cache_size() == 1

    # constructing from "interval" must not add a second entry
    IntervalDtype("interval")
    assert cache_size() == 1

    # neither must a pickle round trip
    tm.round_trip_pickle(dtype)
    assert cache_size() == 1

0 commit comments

Comments
 (0)