Skip to content

Commit 98ac047

Browse files
committed
API: register_extension_dtype class decorator
Closes #22664
1 parent 0976e12 commit 98ac047

File tree

10 files changed

+27
-26
lines changed

10 files changed

+27
-26
lines changed

doc/source/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ objects.
25592559
.. autosummary::
25602560
:toctree: generated/
25612561

2562+
api.extensions.register_extension_dtype
25622563
api.extensions.register_dataframe_accessor
25632564
api.extensions.register_series_accessor
25642565
api.extensions.register_index_accessor

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ ExtensionType Changes
491491
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
492492
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
493493
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
494+
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
494495

495496
.. _whatsnew_0240.api.incompatibilities:
496497

pandas/api/extensions/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
from pandas.core.algorithms import take # noqa
66
from pandas.core.arrays.base import (ExtensionArray, # noqa
77
ExtensionScalarOpsMixin)
8-
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa
8+
from pandas.core.dtypes.dtypes import ( # noqa
9+
ExtensionDtype, register_extension_dtype
10+
)

pandas/core/dtypes/dtypes.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@
88
from .base import ExtensionDtype, _DtypeOpsMixin
99

1010

11+
def register_extension_dtype(cls):
12+
"""Class decorator to register an ExtensionType with pandas.
13+
14+
This enables operations like ``.astype(name)`` for the name
15+
of the ExtensionDtype.
16+
17+
Examples
18+
--------
19+
>>> from pandas.api.extensions import register_extension_dtype
20+
>>> from pandas.api.extensions import ExtensionDtype
21+
>>> @register_extension_dtype
22+
... class MyExtensionDtype(ExtensionDtype):
23+
... pass
24+
"""
25+
registry.register(cls)
26+
return cls
27+
28+
1129
class Registry(object):
1230
"""
1331
Registry for dtype inference
@@ -17,10 +35,6 @@ class Registry(object):
1735
1836
Multiple extension types can be registered.
1937
These are tried in order.
20-
21-
Examples
22-
--------
23-
registry.register(MyExtensionDtype)
2438
"""
2539
def __init__(self):
2640
self.dtypes = []

pandas/tests/extension/base/dtype.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import numpy as np
32
import pandas as pd
43

@@ -51,10 +50,6 @@ def test_eq_with_numpy_object(self, dtype):
5150
def test_array_type(self, data, dtype):
5251
assert dtype.construct_array_type() is type(data)
5352

54-
def test_array_type_with_arg(self, data, dtype):
55-
with pytest.raises(NotImplementedError):
56-
dtype.construct_array_type('foo')
57-
5853
def test_check_dtype(self, data):
5954
dtype = data.dtype
6055

pandas/tests/extension/decimal/test_decimal.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
105105

106106

107107
class TestDtype(BaseDecimal, base.BaseDtypeTests):
108-
109-
def test_array_type_with_arg(self, data, dtype):
110-
assert dtype.construct_array_type() is DecimalArray
108+
pass
111109

112110

113111
class TestInterface(BaseDecimal, base.BaseInterfaceTests):

pandas/tests/extension/json/test_json.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
116116

117117

118118
class TestDtype(BaseJSON, base.BaseDtypeTests):
119-
120-
def test_array_type_with_arg(self, data, dtype):
121-
assert dtype.construct_array_type() is JSONArray
122-
119+
pass
123120

124121
class TestInterface(BaseJSON, base.BaseInterfaceTests):
125122
def test_custom_asserts(self):

pandas/tests/extension/test_categorical.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def data_for_grouping():
7777

7878

7979
class TestDtype(base.BaseDtypeTests):
80-
81-
def test_array_type_with_arg(self, data, dtype):
82-
assert dtype.construct_array_type() is Categorical
80+
pass
8381

8482

8583
class TestInterface(base.BaseInterfaceTests):

pandas/tests/extension/test_integer.py

-3
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ def test_is_dtype_unboxes_dtype(self):
9292
# we have multiple dtypes, so skip
9393
pass
9494

95-
def test_array_type_with_arg(self, data, dtype):
96-
assert dtype.construct_array_type() is IntegerArray
97-
9895

9996
class TestArithmeticOps(base.BaseArithmeticOpsTests):
10097

pandas/tests/extension/test_interval.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ class BaseInterval(object):
8484

8585

8686
class TestDtype(BaseInterval, base.BaseDtypeTests):
87-
88-
def test_array_type_with_arg(self, data, dtype):
89-
assert dtype.construct_array_type() is IntervalArray
87+
pass
9088

9189

9290
class TestCasting(BaseInterval, base.BaseCastingTests):

0 commit comments

Comments
 (0)