Skip to content

Commit 55325b2

Browse files
jreback authored and pcluo committed
COMPAT: ensure proper extension dtypes don't pickle the cache (pandas-dev#16207)
xref pandas-dev#16201
1 parent edb67fe commit 55325b2

File tree

2 files changed

+119
-18
lines changed

2 files changed

+119
-18
lines changed

pandas/core/dtypes/dtypes.py

+25-3
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ class ExtensionDtype(object):
2424
isbuiltin = 0
2525
isnative = 0
2626
_metadata = []
27+
_cache = {}
2728

2829
def __unicode__(self):
2930
return self.name
@@ -71,6 +72,15 @@ def __eq__(self, other):
7172
def __ne__(self, other):
7273
return not self.__eq__(other)
7374

75+
def __getstate__(self):
76+
# pickle support; we don't want to pickle the cache
77+
return {k: getattr(self, k, None) for k in self._metadata}
78+
79+
@classmethod
80+
def reset_cache(cls):
81+
""" clear the cache """
82+
cls._cache = {}
83+
7484
@classmethod
7585
def is_dtype(cls, dtype):
7686
""" Return a boolean if the passed type is an actual dtype that
@@ -110,6 +120,7 @@ class CategoricalDtype(ExtensionDtype):
110120
kind = 'O'
111121
str = '|O08'
112122
base = np.dtype('O')
123+
_metadata = []
113124
_cache = {}
114125

115126
def __new__(cls):
@@ -408,9 +419,15 @@ def __new__(cls, subtype=None):
408419

409420
if isinstance(subtype, IntervalDtype):
410421
return subtype
411-
elif subtype is None or (isinstance(subtype, compat.string_types) and
412-
subtype == 'interval'):
413-
subtype = None
422+
elif subtype is None:
423+
# we are called as an empty constructor
424+
# generally for pickle compat
425+
u = object.__new__(cls)
426+
u.subtype = None
427+
return u
428+
elif (isinstance(subtype, compat.string_types) and
429+
subtype == 'interval'):
430+
subtype = ''
414431
else:
415432
if isinstance(subtype, compat.string_types):
416433
m = cls._match.search(subtype)
@@ -423,6 +440,11 @@ def __new__(cls, subtype=None):
423440
except TypeError:
424441
raise ValueError("could not construct IntervalDtype")
425442

443+
if subtype is None:
444+
u = object.__new__(cls)
445+
u.subtype = None
446+
return u
447+
426448
try:
427449
return cls._cache[str(subtype)]
428450
except KeyError:

pandas/tests/dtypes/test_dtypes.py

+94-15
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,9 @@
2323

2424
class Base(object):
2525

26+
def setup_method(self, method):
27+
self.dtype = self.create()
28+
2629
def test_hash(self):
2730
hash(self.dtype)
2831

@@ -37,14 +40,38 @@ def test_numpy_informed(self):
3740
assert not np.str_ == self.dtype
3841

3942
def test_pickle(self):
43+
# make sure our cache is NOT pickled
44+
45+
# clear the cache
46+
type(self.dtype).reset_cache()
47+
assert not len(self.dtype._cache)
48+
49+
# force back to the cache
4050
result = tm.round_trip_pickle(self.dtype)
51+
assert not len(self.dtype._cache)
4152
assert result == self.dtype
4253

4354

44-
class TestCategoricalDtype(Base, tm.TestCase):
55+
class TestCategoricalDtype(Base):
56+
57+
def create(self):
58+
return CategoricalDtype()
59+
60+
def test_pickle(self):
61+
# make sure our cache is NOT pickled
62+
63+
# clear the cache
64+
type(self.dtype).reset_cache()
65+
assert not len(self.dtype._cache)
4566

46-
def setUp(self):
47-
self.dtype = CategoricalDtype()
67+
# force back to the cache
68+
result = tm.round_trip_pickle(self.dtype)
69+
70+
# we are a singular object so we are added
71+
# back to the cache upon unpickling
72+
# this is to ensure object identity
73+
assert len(self.dtype._cache) == 1
74+
assert result == self.dtype
4875

4976
def test_hash_vs_equality(self):
5077
# make sure that we satisfy is semantics
@@ -93,10 +120,10 @@ def test_basic(self):
93120
assert not is_categorical(1.0)
94121

95122

96-
class TestDatetimeTZDtype(Base, tm.TestCase):
123+
class TestDatetimeTZDtype(Base):
97124

98-
def setUp(self):
99-
self.dtype = DatetimeTZDtype('ns', 'US/Eastern')
125+
def create(self):
126+
return DatetimeTZDtype('ns', 'US/Eastern')
100127

101128
def test_hash_vs_equality(self):
102129
# make sure that we satisfy is semantics
@@ -209,10 +236,24 @@ def test_empty(self):
209236
str(dt)
210237

211238

212-
class TestPeriodDtype(Base, tm.TestCase):
239+
class TestPeriodDtype(Base):
213240

214-
def setUp(self):
215-
self.dtype = PeriodDtype('D')
241+
def create(self):
242+
return PeriodDtype('D')
243+
244+
def test_hash_vs_equality(self):
245+
# make sure that we satisfy is semantics
246+
dtype = self.dtype
247+
dtype2 = PeriodDtype('D')
248+
dtype3 = PeriodDtype(dtype2)
249+
assert dtype == dtype2
250+
assert dtype2 == dtype
251+
assert dtype3 == dtype
252+
assert dtype is dtype2
253+
assert dtype2 is dtype
254+
assert dtype3 is dtype
255+
assert hash(dtype) == hash(dtype2)
256+
assert hash(dtype) == hash(dtype3)
216257

217258
def test_construction(self):
218259
with pytest.raises(ValueError):
@@ -338,11 +379,37 @@ def test_not_string(self):
338379
assert not is_string_dtype(PeriodDtype('D'))
339380

340381

341-
class TestIntervalDtype(Base, tm.TestCase):
382+
class TestIntervalDtype(Base):
383+
384+
def create(self):
385+
return IntervalDtype('int64')
386+
387+
def test_hash_vs_equality(self):
388+
# make sure that we satisfy is semantics
389+
dtype = self.dtype
390+
dtype2 = IntervalDtype('int64')
391+
dtype3 = IntervalDtype(dtype2)
392+
assert dtype == dtype2
393+
assert dtype2 == dtype
394+
assert dtype3 == dtype
395+
assert dtype is dtype2
396+
assert dtype2 is dtype
397+
assert dtype3 is dtype
398+
assert hash(dtype) == hash(dtype2)
399+
assert hash(dtype) == hash(dtype3)
342400

343-
# TODO: placeholder
344-
def setUp(self):
345-
self.dtype = IntervalDtype('int64')
401+
dtype1 = IntervalDtype('interval')
402+
dtype2 = IntervalDtype(dtype1)
403+
dtype3 = IntervalDtype('interval')
404+
assert dtype2 == dtype1
405+
assert dtype2 == dtype2
406+
assert dtype2 == dtype3
407+
assert dtype2 is dtype1
408+
assert dtype2 is dtype2
409+
assert dtype2 is dtype3
410+
assert hash(dtype2) == hash(dtype1)
411+
assert hash(dtype2) == hash(dtype2)
412+
assert hash(dtype2) == hash(dtype3)
346413

347414
def test_construction(self):
348415
with pytest.raises(ValueError):
@@ -356,9 +423,9 @@ def test_construction(self):
356423
def test_construction_generic(self):
357424
# generic
358425
i = IntervalDtype('interval')
359-
assert i.subtype is None
426+
assert i.subtype == ''
360427
assert is_interval_dtype(i)
361-
assert str(i) == 'interval'
428+
assert str(i) == 'interval[]'
362429

363430
i = IntervalDtype()
364431
assert i.subtype is None
@@ -445,3 +512,15 @@ def test_basic_dtype(self):
445512
assert not is_interval_dtype(np.object_)
446513
assert not is_interval_dtype(np.int64)
447514
assert not is_interval_dtype(np.float64)
515+
516+
def test_caching(self):
517+
IntervalDtype.reset_cache()
518+
dtype = IntervalDtype("int64")
519+
assert len(IntervalDtype._cache) == 1
520+
521+
IntervalDtype("interval")
522+
assert len(IntervalDtype._cache) == 2
523+
524+
IntervalDtype.reset_cache()
525+
tm.round_trip_pickle(dtype)
526+
assert len(IntervalDtype._cache) == 0

0 commit comments

Comments
 (0)