Skip to content

API: register_extension_dtype class decorator #22666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2559,6 +2559,7 @@ objects.
.. autosummary::
:toctree: generated/

api.extensions.register_extension_dtype
api.extensions.register_dataframe_accessor
api.extensions.register_series_accessor
api.extensions.register_index_accessor
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ ExtensionType Changes
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`).
- Added :func:`pandas.api.extensions.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)

.. _whatsnew_0240.api.incompatibilities:

Expand Down
4 changes: 3 additions & 1 deletion pandas/api/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
from pandas.core.algorithms import take # noqa
from pandas.core.arrays.base import (ExtensionArray, # noqa
ExtensionScalarOpsMixin)
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa
from pandas.core.dtypes.dtypes import ( # noqa
ExtensionDtype, register_extension_dtype
)
8 changes: 4 additions & 4 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
is_list_like)
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.dtypes import registry
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.core.dtypes.missing import isna, notna

from pandas.io.formats.printing import (
Expand Down Expand Up @@ -614,9 +614,9 @@ def integer_arithmetic_method(self, other):
classname = "{}Dtype".format(name)
attributes_dict = {'type': getattr(np, dtype),
'name': name}
dtype_type = type(classname, (_IntegerDtype, ), attributes_dict)
dtype_type = register_extension_dtype(
type(classname, (_IntegerDtype, ), attributes_dict)
)
setattr(module, classname, dtype_type)

# register
registry.register(dtype_type)
_dtypes[dtype] = dtype_type()
8 changes: 7 additions & 1 deletion pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class ExtensionDtype(_DtypeOpsMixin):
* _is_numeric

Optionally one can override construct_array_type for construction
with the name of this dtype via the Registry
with the name of this dtype via the Registry. See
:func:`pandas.api.extensions.register_extension_dtype`.

* construct_array_type

Expand All @@ -138,6 +139,11 @@ class ExtensionDtype(_DtypeOpsMixin):
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
provided for registering virtual subclasses.

See Also
--------
pandas.api.extensions.register_extension_dtype
pandas.api.extensions.ExtensionArray
"""

def __str__(self):
Expand Down
36 changes: 26 additions & 10 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
from .base import ExtensionDtype, _DtypeOpsMixin


def register_extension_dtype(cls):
    """
    Register an extension dtype class with pandas; usable as a class decorator.

    .. versionadded:: 0.24.0

    The class is handed to the module-level ``registry``. This enables
    operations like ``.astype(name)`` for the name of the ExtensionDtype.

    Parameters
    ----------
    cls : type
        The extension dtype class to register.

    Returns
    -------
    type
        ``cls`` itself, returned unchanged so that the function can be
        applied with decorator syntax.

    Examples
    --------
    >>> from pandas.api.extensions import register_extension_dtype
    >>> from pandas.api.extensions import ExtensionDtype
    >>> @register_extension_dtype
    ... class MyExtensionDtype(ExtensionDtype):
    ...     pass
    """
    # Register first, then hand the very same class object back to the
    # caller so the decorated name still refers to the original class.
    registry.register(cls)
    return cls


class Registry(object):
"""
Registry for dtype inference
Expand All @@ -17,10 +37,6 @@ class Registry(object):

Multiple extension types can be registered.
These are tried in order.

Examples
--------
registry.register(MyExtensionDtype)
"""
def __init__(self):
self.dtypes = []
Expand Down Expand Up @@ -65,9 +81,6 @@ def find(self, dtype):


registry = Registry()
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()


class PandasExtensionDtype(_DtypeOpsMixin):
Expand Down Expand Up @@ -145,6 +158,7 @@ class CategoricalDtypeType(type):
pass


@register_extension_dtype
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
"""
Type for categorical data with the categories and orderedness
Expand Down Expand Up @@ -692,6 +706,7 @@ class IntervalDtypeType(type):
pass


@register_extension_dtype
class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
"""
An Interval duck-typed class, suitable for holding an interval
Expand Down Expand Up @@ -824,8 +839,9 @@ def is_dtype(cls, dtype):
return super(IntervalDtype, cls).is_dtype(dtype)


# register the dtypes in search order
registry.register(IntervalDtype)
registry.register(CategoricalDtype)
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()

_pandas_registry.register(DatetimeTZDtype)
_pandas_registry.register(PeriodDtype)
5 changes: 0 additions & 5 deletions pandas/tests/extension/base/dtype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -51,10 +50,6 @@ def test_eq_with_numpy_object(self, dtype):
def test_array_type(self, data, dtype):
assert dtype.construct_array_type() is type(data)

def test_array_type_with_arg(self, data, dtype):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test could never pass or fail correctly. It would always raise a TypeError since the method doesn't take an argument.

with pytest.raises(NotImplementedError):
dtype.construct_array_type('foo')

def test_check_dtype(self, data):
dtype = data.dtype

Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseDecimal, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is DecimalArray
pass


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseJSON, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is JSONArray
pass


class TestInterface(BaseJSON, base.BaseInterfaceTests):
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def data_for_grouping():


class TestDtype(base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a legit test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's it testing? That code wasn't ever run on master.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean with "That code wasn't ever run on master." ? That the test is never run, or that what it is testing is never used anywhere in the code?

Those tests are certainly running, but it is true that the base class test is not very useful (and also wrong with the string argument).
But I think it is still potentially useful to assert that the dtypes construct_array_type is giving an ExtensionArray class object

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test was overridden by every subclass. Note that we still have https://github.com/pandas-dev/pandas/pull/22666/files/9b646e28656ca3c61ee8a221e16ad74bba2610c3#diff-32e4b328fc01507825a6249caac0cb21R50, which tests this properly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we still have .., which tests this properly.

Ah, yes, that is indeed what the other overridden ones were also doing, but then manually.

assert dtype.construct_array_type() is Categorical
pass


class TestInterface(base.BaseInterfaceTests):
Expand Down
5 changes: 1 addition & 4 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pandas.tests.extension import base
from pandas.core.dtypes.common import is_extension_array_dtype

from pandas.core.arrays import IntegerArray, integer_array
from pandas.core.arrays import integer_array
from pandas.core.arrays.integer import (
Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype)
Expand Down Expand Up @@ -92,9 +92,6 @@ def test_is_dtype_unboxes_dtype(self):
# we have multiple dtypes, so skip
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is IntegerArray


class TestArithmeticOps(base.BaseArithmeticOpsTests):

Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ class BaseInterval(object):


class TestDtype(BaseInterval, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is IntervalArray
pass


class TestCasting(BaseInterval, base.BaseCastingTests):
Expand Down