Skip to content

Commit 56f0a9b

Browse files
committed
ENH: add in extension dtype registry
1 parent f91e28c commit 56f0a9b

23 files changed

+273
-58
lines changed

doc/source/whatsnew/v0.24.0.txt

+9-2
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,22 @@ Other Enhancements
1414
^^^^^^^^^^^^^^^^^^
1515
-
1616
-
17-
-
1817

1918
.. _whatsnew_0240.api_breaking:
2019

2120

2221
Backwards incompatible API changes
2322
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2423

24+
.. _whatsnew_0240.api.extension:
25+
26+
ExtensionType Changes
27+
^^^^^^^^^^^^^^^^^^^^^
28+
29+
- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` and ``.append()``, and attribute ``array_type`` (:issue:`21185`)
30+
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instaniate a registered ``DecimalDtype`` (:issue:`21185`)
31+
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
32+
2533
.. _whatsnew_0240.api.other:
2634

2735
Other API Changes
@@ -177,4 +185,3 @@ Other
177185
-
178186
-
179187
-
180-

pandas/core/algorithms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _reconstruct_data(values, dtype, original):
154154
"""
155155
from pandas import Index
156156
if is_extension_array_dtype(dtype):
157-
pass
157+
values = dtype.array_type._from_sequence(values)
158158
elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype):
159159
values = Index(original)._shallow_copy(values, name=None)
160160
elif is_bool_dtype(dtype):
@@ -705,7 +705,7 @@ def value_counts(values, sort=True, ascending=False, normalize=False,
705705

706706
else:
707707

708-
if is_categorical_dtype(values) or is_sparse(values):
708+
if is_extension_array_dtype(values) or is_sparse(values):
709709

710710
# handle Categorical and sparse,
711711
result = Series(values)._values.value_counts(dropna=dropna)

pandas/core/arrays/base.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ class ExtensionArray(object):
3636
* isna
3737
* take
3838
* copy
39+
* append
3940
* _concat_same_type
41+
* array_type
4042
4143
An additional method is available to satisfy pandas' internal,
4244
private block API.
@@ -49,6 +51,7 @@ class ExtensionArray(object):
4951
methods:
5052
5153
* fillna
54+
* dropna
5255
* unique
5356
* factorize / _values_for_factorize
5457
* argsort / _values_for_argsort
@@ -82,14 +85,16 @@ class ExtensionArray(object):
8285
# Constructors
8386
# ------------------------------------------------------------------------
8487
@classmethod
85-
def _from_sequence(cls, scalars):
88+
def _from_sequence(cls, scalars, copy=False):
8689
"""Construct a new ExtensionArray from a sequence of scalars.
8790
8891
Parameters
8992
----------
9093
scalars : Sequence
9194
Each element will be an instance of the scalar type for this
9295
array, ``cls.dtype.type``.
96+
copy : boolean, default True
97+
if True, copy the underlying data
9398
Returns
9499
-------
95100
ExtensionArray
@@ -379,6 +384,16 @@ def fillna(self, value=None, method=None, limit=None):
379384
new_values = self.copy()
380385
return new_values
381386

387+
def dropna(self):
388+
""" Return ExtensionArray without NA values
389+
390+
Returns
391+
-------
392+
valid : ExtensionArray
393+
"""
394+
395+
return self[~self.isna()]
396+
382397
def unique(self):
383398
"""Compute the ExtensionArray of unique values.
384399
@@ -567,6 +582,34 @@ def copy(self, deep=False):
567582
"""
568583
raise AbstractMethodError(self)
569584

585+
def append(self, other):
586+
"""
587+
Append a collection of Arrays together
588+
589+
Parameters
590+
----------
591+
other : ExtensionArray or list/tuple of ExtensionArrays
592+
593+
Returns
594+
-------
595+
appended : ExtensionArray
596+
"""
597+
598+
to_concat = [self]
599+
cls = self.__class__
600+
601+
if isinstance(other, (list, tuple)):
602+
to_concat = to_concat + list(other)
603+
else:
604+
to_concat.append(other)
605+
606+
for obj in to_concat:
607+
if not isinstance(obj, cls):
608+
raise TypeError('all inputs must be of type {}'.format(
609+
cls.__name__))
610+
611+
return cls._concat_same_type(to_concat)
612+
570613
# ------------------------------------------------------------------------
571614
# Block-related methods
572615
# ------------------------------------------------------------------------

pandas/core/arrays/categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -2343,6 +2343,10 @@ def isin(self, values):
23432343
return algorithms.isin(self.codes, code_values)
23442344

23452345

2346+
# inform the Dtype about us
2347+
CategoricalDtype.array_type = Categorical
2348+
2349+
23462350
# The Series.cat accessor
23472351

23482352

pandas/core/dtypes/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def name(self):
156156
"""
157157
raise AbstractMethodError(self)
158158

159+
@property
160+
def array_type(self):
161+
"""Return the array type associated with this dtype
162+
"""
163+
raise AbstractMethodError(self)
164+
159165
@classmethod
160166
def construct_from_string(cls, string):
161167
"""Attempt to construct this type from a string.

pandas/core/dtypes/cast.py

+5
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,11 @@ def conv(r, dtype):
647647
def astype_nansafe(arr, dtype, copy=True):
648648
""" return a view if copy is False, but
649649
need to be very careful as the result shape could change! """
650+
651+
# dispatch on extension dtype if needed
652+
if is_extension_array_dtype(dtype):
653+
return dtype.array_type._from_sequence(arr, copy=copy)
654+
650655
if not isinstance(dtype, np.dtype):
651656
dtype = pandas_dtype(dtype)
652657

pandas/core/dtypes/common.py

+7-32
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
DatetimeTZDtype, DatetimeTZDtypeType,
1010
PeriodDtype, PeriodDtypeType,
1111
IntervalDtype, IntervalDtypeType,
12-
ExtensionDtype, PandasExtensionDtype)
12+
ExtensionDtype, registry)
1313
from .generic import (ABCCategorical, ABCPeriodIndex,
1414
ABCDatetimeIndex, ABCSeries,
1515
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
@@ -1975,38 +1975,13 @@ def pandas_dtype(dtype):
19751975
np.dtype or a pandas dtype
19761976
"""
19771977

1978-
if isinstance(dtype, DatetimeTZDtype):
1979-
return dtype
1980-
elif isinstance(dtype, PeriodDtype):
1981-
return dtype
1982-
elif isinstance(dtype, CategoricalDtype):
1983-
return dtype
1984-
elif isinstance(dtype, IntervalDtype):
1985-
return dtype
1986-
elif isinstance(dtype, string_types):
1987-
try:
1988-
return DatetimeTZDtype.construct_from_string(dtype)
1989-
except TypeError:
1990-
pass
1991-
1992-
if dtype.startswith('period[') or dtype.startswith('Period['):
1993-
# do not parse string like U as period[U]
1994-
try:
1995-
return PeriodDtype.construct_from_string(dtype)
1996-
except TypeError:
1997-
pass
1998-
1999-
elif dtype.startswith('interval') or dtype.startswith('Interval'):
2000-
try:
2001-
return IntervalDtype.construct_from_string(dtype)
2002-
except TypeError:
2003-
pass
1978+
# registered extension types
1979+
result = registry.find(dtype)
1980+
if result is not None:
1981+
return result
20041982

2005-
try:
2006-
return CategoricalDtype.construct_from_string(dtype)
2007-
except TypeError:
2008-
pass
2009-
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
1983+
# un-registered extension types
1984+
if isinstance(dtype, ExtensionDtype):
20101985
return dtype
20111986

20121987
try:

pandas/core/dtypes/dtypes.py

+86
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,70 @@
22

33
import re
44
import numpy as np
5+
from collections import OrderedDict
56
from pandas import compat
67
from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex
78

89
from .base import ExtensionDtype, _DtypeOpsMixin
910

1011

12+
class Registry(object):
13+
""" Registry for dtype inference
14+
15+
We can directly construct dtypes in pandas_dtypes if they are
16+
a type; the registry allows us to register an extension dtype
17+
to try inference from a string or a dtype class
18+
19+
These are tried in order for inference.
20+
"""
21+
dtypes = OrderedDict()
22+
23+
@classmethod
24+
def register(self, dtype, constructor=None):
25+
"""
26+
Parameters
27+
----------
28+
dtype : PandasExtension Dtype
29+
"""
30+
if not issubclass(dtype, (PandasExtensionDtype, ExtensionDtype)):
31+
raise ValueError("can only register pandas extension dtypes")
32+
33+
if constructor is None:
34+
constructor = dtype.construct_from_string
35+
36+
self.dtypes[dtype] = constructor
37+
38+
def find(self, dtype):
39+
"""
40+
Parameters
41+
----------
42+
dtype : PandasExtensionDtype or string
43+
44+
Returns
45+
-------
46+
return the first matching dtype, otherwise return None
47+
"""
48+
if not isinstance(dtype, compat.string_types):
49+
dtype_type = dtype
50+
if not isinstance(dtype, type):
51+
dtype_type = type(dtype)
52+
if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)):
53+
return dtype
54+
55+
return None
56+
57+
for dtype_type, constructor in self.dtypes.items():
58+
try:
59+
return constructor(dtype)
60+
except TypeError:
61+
pass
62+
63+
return None
64+
65+
66+
registry = Registry()
67+
68+
1169
class PandasExtensionDtype(_DtypeOpsMixin):
1270
"""
1371
A np.dtype duck-typed class, suitable for holding a custom dtype.
@@ -564,6 +622,17 @@ def construct_from_string(cls, string):
564622
pass
565623
raise TypeError("could not construct PeriodDtype")
566624

625+
@classmethod
626+
def construct_from_string_strict(cls, string):
627+
"""
628+
Strict construction from a string, raise a TypeError if not
629+
possible
630+
"""
631+
if string.startswith('period[') or string.startswith('Period['):
632+
# do not parse string like U as period[U]
633+
return PeriodDtype.construct_from_string(string)
634+
raise TypeError("could not construct PeriodDtype")
635+
567636
def __unicode__(self):
568637
return "period[{freq}]".format(freq=self.freq.freqstr)
569638

@@ -683,6 +752,16 @@ def construct_from_string(cls, string):
683752
msg = "a string needs to be passed, got type {typ}"
684753
raise TypeError(msg.format(typ=type(string)))
685754

755+
@classmethod
756+
def construct_from_string_strict(cls, string):
757+
"""
758+
Strict construction from a string, raise a TypeError if not
759+
possible
760+
"""
761+
if string.startswith('interval') or string.startswith('Interval'):
762+
return IntervalDtype.construct_from_string(string)
763+
raise TypeError("cannot construct IntervalDtype")
764+
686765
def __unicode__(self):
687766
if self.subtype is None:
688767
return "interval"
@@ -723,3 +802,10 @@ def is_dtype(cls, dtype):
723802
else:
724803
return False
725804
return super(IntervalDtype, cls).is_dtype(dtype)
805+
806+
807+
# register the dtypes in search order
808+
registry.register(DatetimeTZDtype)
809+
registry.register(PeriodDtype, PeriodDtype.construct_from_string_strict)
810+
registry.register(IntervalDtype, IntervalDtype.construct_from_string_strict)
811+
registry.register(CategoricalDtype)

pandas/core/internals.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,9 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
633633
return self.make_block(Categorical(self.values, dtype=dtype))
634634

635635
# astype processing
636-
dtype = np.dtype(dtype)
637-
if self.dtype == dtype:
636+
if not is_extension_array_dtype(dtype):
637+
dtype = np.dtype(dtype)
638+
if is_dtype_equal(self.dtype, dtype):
638639
if copy:
639640
return self.copy()
640641
return self
@@ -662,7 +663,13 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
662663

663664
# _astype_nansafe works fine with 1-d only
664665
values = astype_nansafe(values.ravel(), dtype, copy=True)
665-
values = values.reshape(self.shape)
666+
667+
# TODO(extension)
668+
# should we make this attribute?
669+
try:
670+
values = values.reshape(self.shape)
671+
except AttributeError:
672+
pass
666673

667674
newb = make_block(values, placement=self.mgr_locs,
668675
klass=klass)
@@ -3170,6 +3177,10 @@ def get_block_type(values, dtype=None):
31703177
cls = TimeDeltaBlock
31713178
elif issubclass(vtype, np.complexfloating):
31723179
cls = ComplexBlock
3180+
elif is_categorical(values):
3181+
cls = CategoricalBlock
3182+
elif is_extension_array_dtype(values):
3183+
cls = ExtensionBlock
31733184
elif issubclass(vtype, np.datetime64):
31743185
assert not is_datetimetz(values)
31753186
cls = DatetimeBlock
@@ -3179,10 +3190,6 @@ def get_block_type(values, dtype=None):
31793190
cls = IntBlock
31803191
elif dtype == np.bool_:
31813192
cls = BoolBlock
3182-
elif is_categorical(values):
3183-
cls = CategoricalBlock
3184-
elif is_extension_array_dtype(values):
3185-
cls = ExtensionBlock
31863193
else:
31873194
cls = ObjectBlock
31883195
return cls

0 commit comments

Comments
 (0)