diff --git a/doc/source/api.rst b/doc/source/api.rst index 77d37ec2a7b2e..9c3770a497cf8 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -2559,6 +2559,7 @@ objects. .. autosummary:: :toctree: generated/ + api.extensions.register_extension_dtype api.extensions.register_dataframe_accessor api.extensions.register_series_accessor api.extensions.register_index_accessor diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index fb7af00f61534..0f7e33e007a5b 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -491,6 +491,7 @@ ExtensionType Changes - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`) - :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`) - :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`). +- Added :meth:`pandas.api.extensions.register_extension_dtype` to register an extension type with pandas (:issue:`22664`) .. 
_whatsnew_0240.api.incompatibilities: diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index 851a63725952a..8a515661920f3 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -5,4 +5,6 @@ from pandas.core.algorithms import take # noqa from pandas.core.arrays.base import (ExtensionArray, # noqa ExtensionScalarOpsMixin) -from pandas.core.dtypes.dtypes import ExtensionDtype # noqa +from pandas.core.dtypes.dtypes import ( # noqa + ExtensionDtype, register_extension_dtype +) diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index 5f6a96833c4f8..aebc7a6a04ffc 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -19,7 +19,7 @@ is_list_like) from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.dtypes.dtypes import registry +from pandas.core.dtypes.dtypes import register_extension_dtype from pandas.core.dtypes.missing import isna, notna from pandas.io.formats.printing import ( @@ -614,9 +614,9 @@ def integer_arithmetic_method(self, other): classname = "{}Dtype".format(name) attributes_dict = {'type': getattr(np, dtype), 'name': name} - dtype_type = type(classname, (_IntegerDtype, ), attributes_dict) + dtype_type = register_extension_dtype( + type(classname, (_IntegerDtype, ), attributes_dict) + ) setattr(module, classname, dtype_type) - # register - registry.register(dtype_type) _dtypes[dtype] = dtype_type() diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 1ecb6234ad2d9..7dcdf878231f1 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -127,7 +127,8 @@ class ExtensionDtype(_DtypeOpsMixin): * _is_numeric Optionally one can override construct_array_type for construction - with the name of this dtype via the Registry + with the name of this dtype via the Registry. See + :meth:`pandas.api.extensions.register_extension_dtype`. 
* construct_array_type @@ -138,6 +139,11 @@ class ExtensionDtype(_DtypeOpsMixin): Methods and properties required by the interface raise ``pandas.errors.AbstractMethodError`` and no ``register`` method is provided for registering virtual subclasses. + + See Also + -------- + pandas.api.extensions.register_extension_dtype + pandas.api.extensions.ExtensionArray """ def __str__(self): diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index f53ccc86fc4ff..4fd77e41a1c67 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -8,6 +8,26 @@ from .base import ExtensionDtype, _DtypeOpsMixin +def register_extension_dtype(cls): + """Class decorator to register an ExtensionType with pandas. + + .. versionadded:: 0.24.0 + + This enables operations like ``.astype(name)`` for the name + of the ExtensionDtype. + + Examples + -------- + >>> from pandas.api.extensions import register_extension_dtype + >>> from pandas.api.extensions import ExtensionDtype + >>> @register_extension_dtype + ... class MyExtensionDtype(ExtensionDtype): + ... pass + """ + registry.register(cls) + return cls + + class Registry(object): """ Registry for dtype inference @@ -17,10 +37,6 @@ class Registry(object): Multiple extension types can be registered. These are tried in order. - - Examples - -------- - registry.register(MyExtensionDtype) """ def __init__(self): self.dtypes = [] @@ -65,9 +81,6 @@ def find(self, dtype): registry = Registry() -# TODO(Extension): remove the second registry once all internal extension -# dtypes are real extension dtypes. 
-_pandas_registry = Registry() class PandasExtensionDtype(_DtypeOpsMixin): @@ -145,6 +158,7 @@ class CategoricalDtypeType(type): pass +@register_extension_dtype class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): """ Type for categorical data with the categories and orderedness @@ -692,6 +706,7 @@ class IntervalDtypeType(type): pass +@register_extension_dtype class IntervalDtype(PandasExtensionDtype, ExtensionDtype): """ A Interval duck-typed class, suitable for holding an interval @@ -824,8 +839,9 @@ def is_dtype(cls, dtype): return super(IntervalDtype, cls).is_dtype(dtype) -# register the dtypes in search order -registry.register(IntervalDtype) -registry.register(CategoricalDtype) +# TODO(Extension): remove the second registry once all internal extension +# dtypes are real extension dtypes. +_pandas_registry = Registry() + _pandas_registry.register(DatetimeTZDtype) _pandas_registry.register(PeriodDtype) diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 2125458e8a0ba..02b7c9527769f 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -1,4 +1,3 @@ -import pytest import numpy as np import pandas as pd @@ -51,10 +50,6 @@ def test_eq_with_numpy_object(self, dtype): def test_array_type(self, data, dtype): assert dtype.construct_array_type() is type(data) - def test_array_type_with_arg(self, data, dtype): - with pytest.raises(NotImplementedError): - dtype.construct_array_type('foo') - def test_check_dtype(self, data): dtype = data.dtype diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 04e855242b5e6..03fdd25826b79 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -105,9 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseDecimal, base.BaseDtypeTests): - - def test_array_type_with_arg(self, data, dtype): - 
assert dtype.construct_array_type() is DecimalArray + pass class TestInterface(BaseDecimal, base.BaseInterfaceTests): diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index b9cc3c431528f..0126d771caf7f 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -116,9 +116,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseJSON, base.BaseDtypeTests): - - def test_array_type_with_arg(self, data, dtype): - assert dtype.construct_array_type() is JSONArray + pass class TestInterface(BaseJSON, base.BaseInterfaceTests): diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index b8c73a9efdae8..6c6cf80c16da6 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -77,9 +77,7 @@ def data_for_grouping(): class TestDtype(base.BaseDtypeTests): - - def test_array_type_with_arg(self, data, dtype): - assert dtype.construct_array_type() is Categorical + pass class TestInterface(base.BaseInterfaceTests): diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index 50c0e6dd8b347..57e0922a0b7d9 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -20,7 +20,7 @@ from pandas.tests.extension import base from pandas.core.dtypes.common import is_extension_array_dtype -from pandas.core.arrays import IntegerArray, integer_array +from pandas.core.arrays import integer_array from pandas.core.arrays.integer import ( Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype, UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype) @@ -92,9 +92,6 @@ def test_is_dtype_unboxes_dtype(self): # we have multiple dtypes, so skip pass - def test_array_type_with_arg(self, data, dtype): - assert dtype.construct_array_type() is IntegerArray - class TestArithmeticOps(base.BaseArithmeticOpsTests): diff --git 
a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py index 625619a90ed4c..34b98f590df0d 100644 --- a/pandas/tests/extension/test_interval.py +++ b/pandas/tests/extension/test_interval.py @@ -84,9 +84,7 @@ class BaseInterval(object): class TestDtype(BaseInterval, base.BaseDtypeTests): - - def test_array_type_with_arg(self, data, dtype): - assert dtype.construct_array_type() is IntervalArray + pass class TestCasting(BaseInterval, base.BaseCastingTests):