Skip to content

COMPAT: ensure proper extension dtype's don't pickle the cache #16207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ExtensionDtype(object):
isbuiltin = 0
isnative = 0
_metadata = []
_cache = {}

def __unicode__(self):
return self.name
Expand Down Expand Up @@ -71,6 +72,15 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def __getstate__(self):
# pickle support; we don't want to pickle the cache
return {k: getattr(self, k, None) for k in self._metadata}

@classmethod
def reset_cache(cls):
""" clear the cache """
cls._cache = {}

@classmethod
def is_dtype(cls, dtype):
""" Return a boolean if the passed type is an actual dtype that
Expand Down Expand Up @@ -110,6 +120,7 @@ class CategoricalDtype(ExtensionDtype):
kind = 'O'
str = '|O08'
base = np.dtype('O')
_metadata = []
_cache = {}

def __new__(cls):
Expand Down Expand Up @@ -408,9 +419,15 @@ def __new__(cls, subtype=None):

if isinstance(subtype, IntervalDtype):
return subtype
elif subtype is None or (isinstance(subtype, compat.string_types) and
subtype == 'interval'):
subtype = None
elif subtype is None:
# we are called as an empty constructor
# generally for pickle compat
u = object.__new__(cls)
u.subtype = None
return u
elif (isinstance(subtype, compat.string_types) and
subtype == 'interval'):
subtype = ''
else:
if isinstance(subtype, compat.string_types):
m = cls._match.search(subtype)
Expand All @@ -423,6 +440,11 @@ def __new__(cls, subtype=None):
except TypeError:
raise ValueError("could not construct IntervalDtype")

if subtype is None:
u = object.__new__(cls)
u.subtype = None
return u

try:
return cls._cache[str(subtype)]
except KeyError:
Expand Down
109 changes: 94 additions & 15 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

class Base(object):

def setup_method(self, method):
self.dtype = self.create()

def test_hash(self):
hash(self.dtype)

Expand All @@ -37,14 +40,38 @@ def test_numpy_informed(self):
assert not np.str_ == self.dtype

def test_pickle(self):
# make sure our cache is NOT pickled

# clear the cache
type(self.dtype).reset_cache()
assert not len(self.dtype._cache)

# force back to the cache
result = tm.round_trip_pickle(self.dtype)
assert not len(self.dtype._cache)
assert result == self.dtype


class TestCategoricalDtype(Base, tm.TestCase):
class TestCategoricalDtype(Base):

def create(self):
return CategoricalDtype()

def test_pickle(self):
# make sure our cache is NOT pickled

# clear the cache
type(self.dtype).reset_cache()
assert not len(self.dtype._cache)

def setUp(self):
self.dtype = CategoricalDtype()
# force back to the cache
result = tm.round_trip_pickle(self.dtype)

# we are a singular object so we are added
# back to the cache upon unpickling
# this is to ensure object identity
assert len(self.dtype._cache) == 1
assert result == self.dtype

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
Expand Down Expand Up @@ -93,10 +120,10 @@ def test_basic(self):
assert not is_categorical(1.0)


class TestDatetimeTZDtype(Base, tm.TestCase):
class TestDatetimeTZDtype(Base):

def setUp(self):
self.dtype = DatetimeTZDtype('ns', 'US/Eastern')
def create(self):
return DatetimeTZDtype('ns', 'US/Eastern')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
Expand Down Expand Up @@ -209,10 +236,24 @@ def test_empty(self):
str(dt)


class TestPeriodDtype(Base, tm.TestCase):
class TestPeriodDtype(Base):

def setUp(self):
self.dtype = PeriodDtype('D')
def create(self):
return PeriodDtype('D')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
dtype = self.dtype
dtype2 = PeriodDtype('D')
dtype3 = PeriodDtype(dtype2)
assert dtype == dtype2
assert dtype2 == dtype
assert dtype3 == dtype
assert dtype is dtype2
assert dtype2 is dtype
assert dtype3 is dtype
assert hash(dtype) == hash(dtype2)
assert hash(dtype) == hash(dtype3)

def test_construction(self):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -338,11 +379,37 @@ def test_not_string(self):
assert not is_string_dtype(PeriodDtype('D'))


class TestIntervalDtype(Base, tm.TestCase):
class TestIntervalDtype(Base):

def create(self):
return IntervalDtype('int64')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
dtype = self.dtype
dtype2 = IntervalDtype('int64')
dtype3 = IntervalDtype(dtype2)
assert dtype == dtype2
assert dtype2 == dtype
assert dtype3 == dtype
assert dtype is dtype2
assert dtype2 is dtype
assert dtype3 is dtype
assert hash(dtype) == hash(dtype2)
assert hash(dtype) == hash(dtype3)

# TODO: placeholder
def setUp(self):
self.dtype = IntervalDtype('int64')
dtype1 = IntervalDtype('interval')
dtype2 = IntervalDtype(dtype1)
dtype3 = IntervalDtype('interval')
assert dtype2 == dtype1
assert dtype2 == dtype2
assert dtype2 == dtype3
assert dtype2 is dtype1
assert dtype2 is dtype2
assert dtype2 is dtype3
assert hash(dtype2) == hash(dtype1)
assert hash(dtype2) == hash(dtype2)
assert hash(dtype2) == hash(dtype3)

def test_construction(self):
with pytest.raises(ValueError):
Expand All @@ -356,9 +423,9 @@ def test_construction(self):
def test_construction_generic(self):
# generic
i = IntervalDtype('interval')
assert i.subtype is None
assert i.subtype == ''
assert is_interval_dtype(i)
assert str(i) == 'interval'
assert str(i) == 'interval[]'

i = IntervalDtype()
assert i.subtype is None
Expand Down Expand Up @@ -445,3 +512,15 @@ def test_basic_dtype(self):
assert not is_interval_dtype(np.object_)
assert not is_interval_dtype(np.int64)
assert not is_interval_dtype(np.float64)

def test_caching(self):
IntervalDtype.reset_cache()
dtype = IntervalDtype("int64")
assert len(IntervalDtype._cache) == 1

IntervalDtype("interval")
assert len(IntervalDtype._cache) == 2

IntervalDtype.reset_cache()
tm.round_trip_pickle(dtype)
assert len(IntervalDtype._cache) == 0