Skip to content

Commit a67ac2a

Browse files
committed
COMPAT: extension dtypes (DatetimeTZ, Categorical) are now Singleton cached objects
allows for proper is / == comparisons closes #13285
1 parent 9662d91 commit a67ac2a

File tree

3 files changed

+68
-13
lines changed

3 files changed

+68
-13
lines changed

doc/source/whatsnew/v0.18.2.txt

+1
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ Bug Fixes
246246

247247

248248
- Bug in ``pd.to_datetime()`` when passing invalid datatypes (e.g. bool); will now respect the ``errors`` keyword (:issue:`13176`)
249+
- Bug in extension dtype creation where the created types were not is/identical (:issue:`13285`)
249250

250251
- Bug in ``NaT`` - ``Period`` raises ``AttributeError`` (:issue:`13071`)
251252
- Bug in ``Period`` addition raises ``TypeError`` if ``Period`` is on right hand side (:issue:`13069`)

pandas/tests/types/test_dtypes.py

+24
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ class TestCategoricalDtype(Base, tm.TestCase):
4545
def setUp(self):
4646
self.dtype = CategoricalDtype()
4747

48+
def test_hash_vs_equality(self):
49+
# make sure that we satisfy is semantics
50+
dtype = self.dtype
51+
dtype2 = CategoricalDtype()
52+
self.assertTrue(dtype == dtype2)
53+
self.assertTrue(dtype2 == dtype)
54+
self.assertTrue(dtype is dtype2)
55+
self.assertTrue(dtype2 is dtype)
56+
self.assertTrue(hash(dtype) == hash(dtype2))
57+
4858
def test_equality(self):
4959
self.assertTrue(is_dtype_equal(self.dtype, 'category'))
5060
self.assertTrue(is_dtype_equal(self.dtype, CategoricalDtype()))
@@ -88,6 +98,20 @@ class TestDatetimeTZDtype(Base, tm.TestCase):
8898
def setUp(self):
8999
self.dtype = DatetimeTZDtype('ns', 'US/Eastern')
90100

101+
def test_hash_vs_equality(self):
102+
# make sure that we satisfy is semantics
103+
dtype = self.dtype
104+
dtype2 = DatetimeTZDtype('ns', 'US/Eastern')
105+
dtype3 = DatetimeTZDtype(dtype2)
106+
self.assertTrue(dtype == dtype2)
107+
self.assertTrue(dtype2 == dtype)
108+
self.assertTrue(dtype3 == dtype)
109+
self.assertTrue(dtype is dtype2)
110+
self.assertTrue(dtype2 is dtype)
111+
self.assertTrue(dtype3 is dtype)
112+
self.assertTrue(hash(dtype) == hash(dtype2))
113+
self.assertTrue(hash(dtype) == hash(dtype3))
114+
91115
def test_construction(self):
92116
self.assertRaises(ValueError,
93117
lambda: DatetimeTZDtype('ms', 'US/Eastern'))

pandas/types/dtypes.py

+43-13
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ class CategoricalDtype(ExtensionDtype):
108108
kind = 'O'
109109
str = '|O08'
110110
base = np.dtype('O')
111+
_cache = {}
112+
113+
def __new__(cls):
114+
115+
try:
116+
return cls._cache[cls.name]
117+
except KeyError:
118+
c = object.__new__(cls)
119+
cls._cache[cls.name] = c
120+
return c
111121

112122
def __hash__(self):
113123
# make myself hashable
@@ -155,38 +165,58 @@ class DatetimeTZDtype(ExtensionDtype):
155165
base = np.dtype('M8[ns]')
156166
_metadata = ['unit', 'tz']
157167
_match = re.compile("(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
168+
_cache = {}
169+
170+
def __new__(cls, unit=None, tz=None):
171+
""" Create a new unit if needed, otherwise return from the cache
158172
159-
def __init__(self, unit, tz=None):
160-
"""
161173
Parameters
162174
----------
163175
unit : string unit that this represents, currently must be 'ns'
164176
tz : string tz that this represents
165177
"""
166178

167179
if isinstance(unit, DatetimeTZDtype):
168-
self.unit, self.tz = unit.unit, unit.tz
169-
return
180+
unit, tz = unit.unit, unit.tz
170181

171-
if tz is None:
182+
elif unit is None:
183+
# we are called as an empty constructor
184+
# generally for pickle compat
185+
return object.__new__(cls)
186+
187+
elif tz is None:
172188

173189
# we were passed a string that we can construct
174190
try:
175-
m = self._match.search(unit)
191+
m = cls._match.search(unit)
176192
if m is not None:
177-
self.unit = m.groupdict()['unit']
178-
self.tz = m.groupdict()['tz']
179-
return
193+
unit = m.groupdict()['unit']
194+
tz = m.groupdict()['tz']
180195
except:
181196
raise ValueError("could not construct DatetimeTZDtype")
182197

198+
elif isinstance(unit, compat.string_types):
199+
200+
if unit != 'ns':
201+
raise ValueError("DatetimeTZDtype only supports ns units")
202+
203+
unit = unit
204+
tz = tz
205+
206+
if tz is None:
183207
raise ValueError("DatetimeTZDtype constructor must have a tz "
184208
"supplied")
185209

186-
if unit != 'ns':
187-
raise ValueError("DatetimeTZDtype only supports ns units")
188-
self.unit = unit
189-
self.tz = tz
210+
# set/retrieve from cache
211+
key = (unit, str(tz))
212+
try:
213+
return cls._cache[key]
214+
except KeyError:
215+
u = object.__new__(cls)
216+
u.unit = unit
217+
u.tz = tz
218+
cls._cache[key] = u
219+
return u
190220

191221
@classmethod
192222
def construct_from_string(cls, string):

0 commit comments

Comments
 (0)