diff --git a/doc/source/whatsnew/v0.18.2.txt b/doc/source/whatsnew/v0.18.2.txt index ebae54f292e3c..3e31858bb3683 100644 --- a/doc/source/whatsnew/v0.18.2.txt +++ b/doc/source/whatsnew/v0.18.2.txt @@ -242,7 +242,7 @@ Bug Fixes - Bug in ``Series`` arithmetic raises ``TypeError`` if it contains datetime-like as ``object`` dtype (:issue:`13043`) - +- Bug in extension dtype creation where the created types were not is/identical (:issue:`13285`) - Bug in ``NaT`` - ``Period`` raises ``AttributeError`` (:issue:`13071`) - Bug in ``Period`` addition raises ``TypeError`` if ``Period`` is on right hand side (:issue:`13069`) diff --git a/pandas/tests/types/test_dtypes.py b/pandas/tests/types/test_dtypes.py index 2a9ad30a07805..d48b9baf64777 100644 --- a/pandas/tests/types/test_dtypes.py +++ b/pandas/tests/types/test_dtypes.py @@ -45,6 +45,16 @@ class TestCategoricalDtype(Base, tm.TestCase): def setUp(self): self.dtype = CategoricalDtype() + def test_hash_vs_equality(self): + # make sure that we satisfy is semantics + dtype = self.dtype + dtype2 = CategoricalDtype() + self.assertTrue(dtype == dtype2) + self.assertTrue(dtype2 == dtype) + self.assertTrue(dtype is dtype2) + self.assertTrue(dtype2 is dtype) + self.assertTrue(hash(dtype) == hash(dtype2)) + def test_equality(self): self.assertTrue(is_dtype_equal(self.dtype, 'category')) self.assertTrue(is_dtype_equal(self.dtype, CategoricalDtype())) @@ -88,6 +98,20 @@ class TestDatetimeTZDtype(Base, tm.TestCase): def setUp(self): self.dtype = DatetimeTZDtype('ns', 'US/Eastern') + def test_hash_vs_equality(self): + # make sure that we satisfy is semantics + dtype = self.dtype + dtype2 = DatetimeTZDtype('ns', 'US/Eastern') + dtype3 = DatetimeTZDtype(dtype2) + self.assertTrue(dtype == dtype2) + self.assertTrue(dtype2 == dtype) + self.assertTrue(dtype3 == dtype) + self.assertTrue(dtype is dtype2) + self.assertTrue(dtype2 is dtype) + self.assertTrue(dtype3 is dtype) + self.assertTrue(hash(dtype) == hash(dtype2)) + self.assertTrue(hash(dtype) == hash(dtype3)) + def test_construction(self): self.assertRaises(ValueError, lambda: DatetimeTZDtype('ms', 'US/Eastern')) diff --git a/pandas/types/dtypes.py b/pandas/types/dtypes.py index e6adbc8500117..140d494c3e1b2 100644 --- a/pandas/types/dtypes.py +++ b/pandas/types/dtypes.py @@ -108,6 +108,16 @@ class CategoricalDtype(ExtensionDtype): kind = 'O' str = '|O08' base = np.dtype('O') + _cache = {} + + def __new__(cls): + + try: + return cls._cache[cls.name] + except KeyError: + c = object.__new__(cls) + cls._cache[cls.name] = c + return c def __hash__(self): # make myself hashable @@ -155,9 +165,11 @@ class DatetimeTZDtype(ExtensionDtype): base = np.dtype('M8[ns]') _metadata = ['unit', 'tz'] _match = re.compile("(datetime64|M8)\[(?P.+), (?P.+)\]") + _cache = {} + + def __new__(cls, unit=None, tz=None): + """ Create a new unit if needed, otherwise return from the cache - def __init__(self, unit, tz=None): - """ Parameters ---------- unit : string unit that this represents, currently must be 'ns' @@ -165,28 +177,46 @@ def __init__(self, unit, tz=None): """ if isinstance(unit, DatetimeTZDtype): - self.unit, self.tz = unit.unit, unit.tz - return + unit, tz = unit.unit, unit.tz - if tz is None: + elif unit is None: + # we are called as an empty constructor + # generally for pickle compat + return object.__new__(cls) + + elif tz is None: # we were passed a string that we can construct try: - m = self._match.search(unit) + m = cls._match.search(unit) if m is not None: - self.unit = m.groupdict()['unit'] - self.tz = m.groupdict()['tz'] - return + unit = m.groupdict()['unit'] + tz = m.groupdict()['tz'] except: raise ValueError("could not construct DatetimeTZDtype") + elif isinstance(unit, compat.string_types): + + if unit != 'ns': + raise ValueError("DatetimeTZDtype only supports ns units") + + unit = unit + tz = tz + + if tz is None: raise ValueError("DatetimeTZDtype constructor must have a tz " "supplied") - if unit != 'ns': - raise ValueError("DatetimeTZDtype only supports ns units") - self.unit = unit - self.tz = tz + # set/retrieve from cache + key = (unit, str(tz)) + try: + return cls._cache[key] + except KeyError: + u = object.__new__(cls) + u.unit = unit + u.tz = tz + cls._cache[key] = u + return u @classmethod def construct_from_string(cls, string):