Skip to content

Commit 857515f

Browse files
API: register_extension_dtype class decorator (#22666)
1 parent c040353 commit 857515f

File tree

12 files changed

+47
-37
lines changed

12 files changed

+47
-37
lines changed

doc/source/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ objects.
25592559
.. autosummary::
25602560
:toctree: generated/
25612561

2562+
api.extensions.register_extension_dtype
25622563
api.extensions.register_dataframe_accessor
25632564
api.extensions.register_series_accessor
25642565
api.extensions.register_index_accessor

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ ExtensionType Changes
491491
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
492492
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
493493
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`).
494+
- Added :func:`pandas.api.extensions.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
494495

495496
.. _whatsnew_0240.api.incompatibilities:
496497

pandas/api/extensions/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
from pandas.core.algorithms import take # noqa
66
from pandas.core.arrays.base import (ExtensionArray, # noqa
77
ExtensionScalarOpsMixin)
8-
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa
8+
from pandas.core.dtypes.dtypes import ( # noqa
9+
ExtensionDtype, register_extension_dtype
10+
)

pandas/core/arrays/integer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
is_list_like)
2020
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
2121
from pandas.core.dtypes.base import ExtensionDtype
22-
from pandas.core.dtypes.dtypes import registry
22+
from pandas.core.dtypes.dtypes import register_extension_dtype
2323
from pandas.core.dtypes.missing import isna, notna
2424

2525
from pandas.io.formats.printing import (
@@ -614,9 +614,9 @@ def integer_arithmetic_method(self, other):
614614
classname = "{}Dtype".format(name)
615615
attributes_dict = {'type': getattr(np, dtype),
616616
'name': name}
617-
dtype_type = type(classname, (_IntegerDtype, ), attributes_dict)
617+
dtype_type = register_extension_dtype(
618+
type(classname, (_IntegerDtype, ), attributes_dict)
619+
)
618620
setattr(module, classname, dtype_type)
619621

620-
# register
621-
registry.register(dtype_type)
622622
_dtypes[dtype] = dtype_type()

pandas/core/dtypes/base.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ class ExtensionDtype(_DtypeOpsMixin):
127127
* _is_numeric
128128
129129
Optionally one can override construct_array_type for construction
130-
with the name of this dtype via the Registry
130+
with the name of this dtype via the Registry. See
131+
:func:`pandas.api.extensions.register_extension_dtype`.
131132
132133
* construct_array_type
133134
@@ -138,6 +139,11 @@ class ExtensionDtype(_DtypeOpsMixin):
138139
Methods and properties required by the interface raise
139140
``pandas.errors.AbstractMethodError`` and no ``register`` method is
140141
provided for registering virtual subclasses.
142+
143+
See Also
144+
--------
145+
pandas.api.extensions.register_extension_dtype
146+
pandas.api.extensions.ExtensionArray
141147
"""
142148

143149
def __str__(self):

pandas/core/dtypes/dtypes.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,26 @@
88
from .base import ExtensionDtype, _DtypeOpsMixin
99

1010

11+
def register_extension_dtype(cls):
12+
"""Class decorator to register an ExtensionDtype with pandas.
13+
14+
.. versionadded:: 0.24.0
15+
16+
This enables operations like ``.astype(name)`` for the name
17+
of the ExtensionDtype.
18+
19+
Examples
20+
--------
21+
>>> from pandas.api.extensions import register_extension_dtype
22+
>>> from pandas.api.extensions import ExtensionDtype
23+
>>> @register_extension_dtype
24+
... class MyExtensionDtype(ExtensionDtype):
25+
... pass
26+
"""
27+
registry.register(cls)
28+
return cls
29+
30+
1131
class Registry(object):
1232
"""
1333
Registry for dtype inference
@@ -17,10 +37,6 @@ class Registry(object):
1737
1838
Multiple extension types can be registered.
1939
These are tried in order.
20-
21-
Examples
22-
--------
23-
registry.register(MyExtensionDtype)
2440
"""
2541
def __init__(self):
2642
self.dtypes = []
@@ -65,9 +81,6 @@ def find(self, dtype):
6581

6682

6783
registry = Registry()
68-
# TODO(Extension): remove the second registry once all internal extension
69-
# dtypes are real extension dtypes.
70-
_pandas_registry = Registry()
7184

7285

7386
class PandasExtensionDtype(_DtypeOpsMixin):
@@ -145,6 +158,7 @@ class CategoricalDtypeType(type):
145158
pass
146159

147160

161+
@register_extension_dtype
148162
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
149163
"""
150164
Type for categorical data with the categories and orderedness
@@ -692,6 +706,7 @@ class IntervalDtypeType(type):
692706
pass
693707

694708

709+
@register_extension_dtype
695710
class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
696711
"""
697712
An Interval duck-typed class, suitable for holding an interval
@@ -824,8 +839,9 @@ def is_dtype(cls, dtype):
824839
return super(IntervalDtype, cls).is_dtype(dtype)
825840

826841

827-
# register the dtypes in search order
828-
registry.register(IntervalDtype)
829-
registry.register(CategoricalDtype)
842+
# TODO(Extension): remove the second registry once all internal extension
843+
# dtypes are real extension dtypes.
844+
_pandas_registry = Registry()
845+
830846
_pandas_registry.register(DatetimeTZDtype)
831847
_pandas_registry.register(PeriodDtype)

pandas/tests/extension/base/dtype.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import numpy as np
32
import pandas as pd
43

@@ -51,10 +50,6 @@ def test_eq_with_numpy_object(self, dtype):
5150
def test_array_type(self, data, dtype):
5251
assert dtype.construct_array_type() is type(data)
5352

54-
def test_array_type_with_arg(self, data, dtype):
55-
with pytest.raises(NotImplementedError):
56-
dtype.construct_array_type('foo')
57-
5853
def test_check_dtype(self, data):
5954
dtype = data.dtype
6055

pandas/tests/extension/decimal/test_decimal.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
105105

106106

107107
class TestDtype(BaseDecimal, base.BaseDtypeTests):
108-
109-
def test_array_type_with_arg(self, data, dtype):
110-
assert dtype.construct_array_type() is DecimalArray
108+
pass
111109

112110

113111
class TestInterface(BaseDecimal, base.BaseInterfaceTests):

pandas/tests/extension/json/test_json.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
116116

117117

118118
class TestDtype(BaseJSON, base.BaseDtypeTests):
119-
120-
def test_array_type_with_arg(self, data, dtype):
121-
assert dtype.construct_array_type() is JSONArray
119+
pass
122120

123121

124122
class TestInterface(BaseJSON, base.BaseInterfaceTests):

pandas/tests/extension/test_categorical.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def data_for_grouping():
7777

7878

7979
class TestDtype(base.BaseDtypeTests):
80-
81-
def test_array_type_with_arg(self, data, dtype):
82-
assert dtype.construct_array_type() is Categorical
80+
pass
8381

8482

8583
class TestInterface(base.BaseInterfaceTests):

pandas/tests/extension/test_integer.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pandas.tests.extension import base
2121
from pandas.core.dtypes.common import is_extension_array_dtype
2222

23-
from pandas.core.arrays import IntegerArray, integer_array
23+
from pandas.core.arrays import integer_array
2424
from pandas.core.arrays.integer import (
2525
Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
2626
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype)
@@ -92,9 +92,6 @@ def test_is_dtype_unboxes_dtype(self):
9292
# we have multiple dtypes, so skip
9393
pass
9494

95-
def test_array_type_with_arg(self, data, dtype):
96-
assert dtype.construct_array_type() is IntegerArray
97-
9895

9996
class TestArithmeticOps(base.BaseArithmeticOpsTests):
10097

pandas/tests/extension/test_interval.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ class BaseInterval(object):
8484

8585

8686
class TestDtype(BaseInterval, base.BaseDtypeTests):
87-
88-
def test_array_type_with_arg(self, data, dtype):
89-
assert dtype.construct_array_type() is IntervalArray
87+
pass
9088

9189

9290
class TestCasting(BaseInterval, base.BaseCastingTests):

0 commit comments

Comments
 (0)