Skip to content

Commit 32d6d95

Browse files
committed
add in extension dtype registry
1 parent a5c02d5 commit 32d6d95

File tree

5 files changed

+120
-35
lines changed

5 files changed

+120
-35
lines changed

pandas/core/dtypes/common.py

+5-34
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
DatetimeTZDtype, DatetimeTZDtypeType,
1010
PeriodDtype, PeriodDtypeType,
1111
IntervalDtype, IntervalDtypeType,
12-
ExtensionDtype, PandasExtensionDtype)
12+
ExtensionDtype, registry)
1313
from .generic import (ABCCategorical, ABCPeriodIndex,
1414
ABCDatetimeIndex, ABCSeries,
1515
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
@@ -1975,39 +1975,10 @@ def pandas_dtype(dtype):
19751975
np.dtype or a pandas dtype
19761976
"""
19771977

1978-
if isinstance(dtype, DatetimeTZDtype):
1979-
return dtype
1980-
elif isinstance(dtype, PeriodDtype):
1981-
return dtype
1982-
elif isinstance(dtype, CategoricalDtype):
1983-
return dtype
1984-
elif isinstance(dtype, IntervalDtype):
1985-
return dtype
1986-
elif isinstance(dtype, string_types):
1987-
try:
1988-
return DatetimeTZDtype.construct_from_string(dtype)
1989-
except TypeError:
1990-
pass
1991-
1992-
if dtype.startswith('period[') or dtype.startswith('Period['):
1993-
# do not parse string like U as period[U]
1994-
try:
1995-
return PeriodDtype.construct_from_string(dtype)
1996-
except TypeError:
1997-
pass
1998-
1999-
elif dtype.startswith('interval') or dtype.startswith('Interval'):
2000-
try:
2001-
return IntervalDtype.construct_from_string(dtype)
2002-
except TypeError:
2003-
pass
2004-
2005-
try:
2006-
return CategoricalDtype.construct_from_string(dtype)
2007-
except TypeError:
2008-
pass
2009-
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
2010-
return dtype
1978+
# registered extension types
1979+
result = registry.find(dtype)
1980+
if result is not None:
1981+
return result
20111982

20121983
try:
20131984
npdtype = np.dtype(dtype)

pandas/core/dtypes/dtypes.py

+80
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,64 @@
22

33
import re
44
import numpy as np
5+
from collections import OrderedDict
56
from pandas import compat
67
from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex
78

89
from .base import ExtensionDtype, _DtypeOpsMixin
910

1011

12+
class Registry:
13+
""" class to register our dtypes for inference
14+
15+
We can directly construct dtypes in pandas_dtypes if they are
16+
a type; the registry allows us to register an extension dtype
17+
to try inference from a string or a dtype class
18+
19+
These are tried in order for inference.
20+
"""
21+
dtypes = OrderedDict()
22+
23+
@classmethod
24+
def register(self, dtype, constructor=None):
25+
"""
26+
Parameters
27+
----------
28+
dtype : PandasExtension Dtype
29+
"""
30+
if not issubclass(dtype, PandasExtensionDtype):
31+
raise ValueError("can only register pandas extension dtypes")
32+
33+
if constructor is None:
34+
constructor = dtype.construct_from_string
35+
36+
self.dtypes[dtype] = constructor
37+
38+
def find(self, dtype):
39+
"""
40+
Parameters
41+
----------
42+
dtype : PandasExtensionDtype or string
43+
44+
Returns
45+
-------
46+
return the first matching dtype, otherwise return None
47+
"""
48+
for dtype_type, constructor in self.dtypes.items():
49+
if isinstance(dtype, dtype_type):
50+
return dtype
51+
if isinstance(dtype, compat.string_types):
52+
try:
53+
return constructor(dtype)
54+
except TypeError:
55+
pass
56+
57+
return None
58+
59+
60+
registry = Registry()
61+
62+
1163
class PandasExtensionDtype(_DtypeOpsMixin):
1264
"""
1365
A np.dtype duck-typed class, suitable for holding a custom dtype.
@@ -564,6 +616,17 @@ def construct_from_string(cls, string):
564616
pass
565617
raise TypeError("could not construct PeriodDtype")
566618

619+
@classmethod
620+
def construct_from_string_strict(cls, string):
621+
"""
622+
Strict construction from a string, raise a TypeError if not
623+
possible
624+
"""
625+
if string.startswith('period[') or string.startswith('Period['):
626+
# do not parse string like U as period[U]
627+
return PeriodDtype.construct_from_string(string)
628+
raise TypeError("could not construct PeriodDtype")
629+
567630
def __unicode__(self):
568631
return "period[{freq}]".format(freq=self.freq.freqstr)
569632

@@ -683,6 +746,16 @@ def construct_from_string(cls, string):
683746
msg = "a string needs to be passed, got type {typ}"
684747
raise TypeError(msg.format(typ=type(string)))
685748

749+
@classmethod
750+
def construct_from_string_strict(cls, string):
751+
"""
752+
Strict construction from a string, raise a TypeError if not
753+
possible
754+
"""
755+
if string.startswith('interval') or string.startswith('Interval'):
756+
return IntervalDtype.construct_from_string(string)
757+
raise TypeError("cannot construct IntervalDtype")
758+
686759
def __unicode__(self):
687760
if self.subtype is None:
688761
return "interval"
@@ -723,3 +796,10 @@ def is_dtype(cls, dtype):
723796
else:
724797
return False
725798
return super(IntervalDtype, cls).is_dtype(dtype)
799+
800+
801+
# register the dtypes in search order
802+
registry.register(DatetimeTZDtype)
803+
registry.register(PeriodDtype, PeriodDtype.construct_from_string_strict)
804+
registry.register(IntervalDtype, IntervalDtype.construct_from_string_strict)
805+
registry.register(CategoricalDtype)

pandas/core/series.py

+1
Original file line numberDiff line numberDiff line change
@@ -4060,6 +4060,7 @@ def _try_cast(arr, take_fast_path):
40604060
"Pass the extension array directly.".format(dtype))
40614061
raise ValueError(msg)
40624062

4063+
40634064
elif dtype is not None and raise_cast_failure:
40644065
raise
40654066
else:

pandas/tests/dtypes/test_dtypes.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pandas.compat import string_types
1313
from pandas.core.dtypes.dtypes import (
1414
DatetimeTZDtype, PeriodDtype,
15-
IntervalDtype, CategoricalDtype)
15+
IntervalDtype, CategoricalDtype, registry)
1616
from pandas.core.dtypes.common import (
1717
is_categorical_dtype, is_categorical,
1818
is_datetime64tz_dtype, is_datetimetz,
@@ -767,3 +767,24 @@ def test_update_dtype_errors(self, bad_dtype):
767767
msg = 'a CategoricalDtype must be passed to perform an update, '
768768
with tm.assert_raises_regex(ValueError, msg):
769769
dtype.update_dtype(bad_dtype)
770+
771+
772+
@pytest.mark.parametrize(
773+
'dtype',
774+
[DatetimeTZDtype, CategoricalDtype,
775+
PeriodDtype, IntervalDtype])
776+
def test_registry(dtype):
777+
assert dtype in registry.dtypes
778+
779+
780+
@pytest.mark.parametrize(
781+
'dtype, expected',
782+
[('int64', None),
783+
('interval', IntervalDtype()),
784+
('interval[int64]', IntervalDtype()),
785+
('category', CategoricalDtype()),
786+
('period[D]', PeriodDtype('D')),
787+
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
788+
def test_registry_find(dtype, expected):
789+
790+
assert registry.find(dtype) == expected

pandas/tests/extension/base/constructors.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
import numpy as np
34
import pandas as pd
45
import pandas.util.testing as tm
56
from pandas.core.internals import ExtensionBlock
@@ -45,3 +46,14 @@ def test_series_given_mismatched_index_raises(self, data):
4546
msg = 'Length of passed values is 3, index implies 5'
4647
with tm.assert_raises_regex(ValueError, msg):
4748
pd.Series(data[:3], index=[0, 1, 2, 3, 4])
49+
50+
def test_from_dtype(self, data):
51+
# construct from our dtype & string dtype
52+
dtype = data.dtype
53+
54+
expected = pd.Series(data)
55+
result = pd.Series(np.array(data), dtype=dtype)
56+
self.assert_series_equal(result, expected)
57+
58+
result = pd.Series(np.array(data), dtype=str(dtype))
59+
self.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)