Commit 77f8192

Dr-Irvvictor
authored and
victor
committed
ENH: Support ExtensionArray operators via a mixin (pandas-dev#21261)
1 parent 9716d05 commit 77f8192

15 files changed

+460
-33
lines changed

doc/source/extending.rst

+56-8
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ Extension Types
6161

6262
.. warning::
6363

64-
The :class:`pandas.api.extension.ExtensionDtype` and :class:`pandas.api.extension.ExtensionArray` APIs are new and
64+
The :class:`pandas.api.extensions.ExtensionDtype` and :class:`pandas.api.extensions.ExtensionArray` APIs are new and
6565
experimental. They may change between versions without warning.
6666

6767
Pandas defines an interface for implementing data types and arrays that *extend*
@@ -79,10 +79,10 @@ on :ref:`ecosystem.extensions`.
7979

8080
The interface consists of two classes.
8181

82-
:class:`~pandas.api.extension.ExtensionDtype`
82+
:class:`~pandas.api.extensions.ExtensionDtype`
8383
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
8484

85-
A :class:`pandas.api.extension.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the
85+
A :class:`pandas.api.extensions.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the
8686
data type. Implementors are responsible for a few unique items like the name.
8787

8888
One particularly important item is the ``type`` property. This should be the
@@ -91,7 +91,7 @@ extension array for IP Address data, this might be ``ipaddress.IPv4Address``.
9191

9292
See the `extension dtype source`_ for interface definition.
9393

94-
:class:`~pandas.api.extension.ExtensionArray`
94+
:class:`~pandas.api.extensions.ExtensionArray`
9595
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9696

9797
This class provides all the array-like functionality. ExtensionArrays are
@@ -113,6 +113,54 @@ by some other storage type, like Python lists.
113113
See the `extension array source`_ for the interface definition. The docstrings
114114
and comments contain guidance for properly implementing the interface.
115115

116+
.. _extending.extension.operator:
117+
118+
:class:`~pandas.api.extensions.ExtensionArray` Operator Support
119+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
120+
121+
.. versionadded:: 0.24.0
122+
123+
By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`.
124+
There are two approaches for providing operator support for your ExtensionArray:
125+
126+
1. Define each of the operators on your ``ExtensionArray`` subclass.
127+
2. Use an operator implementation from pandas that depends on operators that are already defined
128+
on the underlying elements (scalars) of the ExtensionArray.
129+
130+
For the first approach, you define selected operators, e.g., ``__add__``, ``__le__``, etc. that
131+
you want your ``ExtensionArray`` subclass to support.
132+
133+
The second approach assumes that the underlying elements (i.e., scalar type) of the ``ExtensionArray``
134+
have the individual operators already defined. In other words, if your ``ExtensionArray``
135+
named ``MyExtensionArray`` is implemented so that each element is an instance
136+
of the class ``MyExtensionElement``, then if the operators are defined
137+
for ``MyExtensionElement``, the second approach will automatically
138+
define the operators for ``MyExtensionArray``.
139+
140+
A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpsMixin`, supports this second
141+
approach. If you are developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``,
142+
you can simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray``,
143+
and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or
144+
:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into
145+
your ``MyExtensionArray`` class, as follows:
146+
147+
.. code-block:: python
148+
149+
    class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin):
150+
        pass
151+
152+
    MyExtensionArray._add_arithmetic_ops()
153+
    MyExtensionArray._add_comparison_ops()
154+
155+
Note that since ``pandas`` automatically calls the underlying operator on each
156+
element one-by-one, this might not be as performant as implementing your own
157+
version of the associated operators directly on the ``ExtensionArray``.
158+
159+
.. _extending.extension.testing:
160+
161+
Testing Extension Arrays
162+
^^^^^^^^^^^^^^^^^^^^^^^^
163+
116164
We provide a test suite for ensuring that your extension arrays satisfy the expected
117165
behavior. To use the test suite, you must provide several pytest fixtures and inherit
118166
from the base test class. The required fixtures are found in
@@ -174,11 +222,11 @@ There are 3 constructor properties to be defined:
174222
Following table shows how ``pandas`` data structures define constructor properties by default.
175223

176224
=========================== ======================= =============
177-
Property Attributes ``Series`` ``DataFrame``
225+
Property Attributes ``Series`` ``DataFrame``
178226
=========================== ======================= =============
179-
``_constructor`` ``Series`` ``DataFrame``
180-
``_constructor_sliced`` ``NotImplementedError`` ``Series``
181-
``_constructor_expanddim`` ``DataFrame`` ``Panel``
227+
``_constructor`` ``Series`` ``DataFrame``
228+
``_constructor_sliced`` ``NotImplementedError`` ``Series``
229+
``_constructor_expanddim`` ``DataFrame`` ``Panel``
182230
=========================== ======================= =============
183231

184232
Below example shows how to define ``SubclassedSeries`` and ``SubclassedDataFrame`` overriding constructor properties.

doc/source/whatsnew/v0.24.0.txt

+16
@@ -10,6 +10,22 @@ New features
1010

1111
- ``ExcelWriter`` now accepts ``mode`` as a keyword argument, enabling append to existing workbooks when using the ``openpyxl`` engine (:issue:`3441`)
1212

13+
.. _whatsnew_0240.enhancements.extension_array_operators:
14+
15+
``ExtensionArray`` operator support
16+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17+
18+
A ``Series`` based on an ``ExtensionArray`` now supports arithmetic and comparison
19+
operators (:issue:`19577`). There are two approaches for providing operator support for an ``ExtensionArray``:
20+
21+
1. Define each of the operators on your ``ExtensionArray`` subclass.
22+
2. Use an operator implementation from pandas that depends on operators that are already defined
23+
on the underlying elements (scalars) of the ``ExtensionArray``.
24+
25+
See the :ref:`ExtensionArray Operator Support
26+
<extending.extension.operator>` documentation section for details on both
27+
ways of adding operator support.
28+
1329
.. _whatsnew_0240.enhancements.other:
1430

1531
Other Enhancements

pandas/api/extensions/__init__.py

+2-1
@@ -3,5 +3,6 @@
33
register_index_accessor,
44
register_series_accessor)
55
from pandas.core.algorithms import take # noqa
6-
from pandas.core.arrays.base import ExtensionArray # noqa
6+
from pandas.core.arrays.base import (ExtensionArray, # noqa
7+
ExtensionScalarOpsMixin)
78
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa

pandas/conftest.py

+18-18
@@ -89,7 +89,8 @@ def observed(request):
8989
'__mul__', '__rmul__',
9090
'__floordiv__', '__rfloordiv__',
9191
'__truediv__', '__rtruediv__',
92-
'__pow__', '__rpow__']
92+
'__pow__', '__rpow__',
93+
'__mod__', '__rmod__']
9394
if not PY3:
9495
_all_arithmetic_operators.extend(['__div__', '__rdiv__'])
9596

@@ -102,6 +103,22 @@ def all_arithmetic_operators(request):
102103
return request.param
103104

104105

106+
@pytest.fixture(params=['__eq__', '__ne__', '__le__',
107+
                        '__lt__', '__ge__', '__gt__'])
108+
def all_compare_operators(request):
109+
    """
110+
    Fixture for dunder names for common compare operations
111+
112+
    * >=
113+
    * >
114+
    * ==
115+
    * !=
116+
    * <
117+
    * <=
118+
    """
119+
    return request.param
120+
121+
105122
@pytest.fixture(params=[None, 'gzip', 'bz2', 'zip',
106123
pytest.param('xz', marks=td.skip_if_no_lzma)])
107124
def compression(request):
@@ -320,20 +337,3 @@ def mock():
320337
return importlib.import_module("unittest.mock")
321338
else:
322339
return pytest.importorskip("mock")
323-
324-
325-
@pytest.fixture(params=['__eq__', '__ne__', '__le__',
326-
'__lt__', '__ge__', '__gt__'])
327-
def all_compare_operators(request):
328-
"""
329-
Fixture for dunder names for common compare operations
330-
331-
* >=
332-
* >
333-
* ==
334-
* !=
335-
* <
336-
* <=
337-
"""
338-
339-
return request.param
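
The fixtures in this file yield only operator names such as ``'__radd__'``, so a test that receives one still has to turn the name into something callable. The following is a minimal sketch of one way to do that with the standard-library ``operator`` module. The helper name ``get_op_from_name`` is hypothetical and is not part of the diff shown here; the fallback assumes that any name missing from ``operator`` is a reflected operator.

```python
import operator


def get_op_from_name(op_name):
    """Return a two-argument callable for a dunder name such as '__radd__'.

    Hypothetical helper for illustration only.
    """
    short_name = op_name.strip("_")
    try:
        # Forward operators ('add', 'le', ...) exist directly in ``operator``.
        return getattr(operator, short_name)
    except AttributeError:
        # Assumption: the name is a reflected operator ('radd', 'rsub', ...).
        # Drop the leading 'r' and swap the operands.
        forward = getattr(operator, short_name[1:])
        return lambda x, y: forward(y, x)


add = get_op_from_name("__add__")
rsub = get_op_from_name("__rsub__")
le = get_op_from_name("__le__")
```

With this helper, ``rsub(2, 10)`` evaluates ``10 - 2``, which mirrors how Python calls ``__rsub__`` on the right-hand operand.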

pandas/core/arrays/__init__.py

+2-1
@@ -1,2 +1,3 @@
1-
from .base import ExtensionArray # noqa
1+
from .base import (ExtensionArray, # noqa
2+
ExtensionScalarOpsMixin)
23
from .categorical import Categorical # noqa

pandas/core/arrays/base.py

+127
@@ -7,8 +7,13 @@
77
"""
88
import numpy as np
99

10+
import operator
11+
1012
from pandas.errors import AbstractMethodError
1113
from pandas.compat.numpy import function as nv
14+
from pandas.compat import set_function_name, PY3
15+
from pandas.core.dtypes.common import is_list_like
16+
from pandas.core import ops
1217

1318
_not_implemented_message = "{} does not implement {}."
1419

@@ -610,3 +615,125 @@ def _ndarray_values(self):
610615
used for interacting with our indexers.
611616
"""
612617
return np.array(self)
618+
619+
620+
class ExtensionOpsMixin(object):
621+
    """
622+
    A base class for linking the operators to their dunder names
623+
    """
624+
    @classmethod
625+
    def _add_arithmetic_ops(cls):
626+
        cls.__add__ = cls._create_arithmetic_method(operator.add)
627+
        cls.__radd__ = cls._create_arithmetic_method(ops.radd)
628+
        cls.__sub__ = cls._create_arithmetic_method(operator.sub)
629+
        cls.__rsub__ = cls._create_arithmetic_method(ops.rsub)
630+
        cls.__mul__ = cls._create_arithmetic_method(operator.mul)
631+
        cls.__rmul__ = cls._create_arithmetic_method(ops.rmul)
632+
        cls.__pow__ = cls._create_arithmetic_method(operator.pow)
633+
        cls.__rpow__ = cls._create_arithmetic_method(ops.rpow)
634+
        cls.__mod__ = cls._create_arithmetic_method(operator.mod)
635+
        cls.__rmod__ = cls._create_arithmetic_method(ops.rmod)
636+
        cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv)
637+
        cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv)
638+
        cls.__truediv__ = cls._create_arithmetic_method(operator.truediv)
639+
        cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv)
640+
        if not PY3:
641+
            cls.__div__ = cls._create_arithmetic_method(operator.div)
642+
            cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv)
643+
644+
        cls.__divmod__ = cls._create_arithmetic_method(divmod)
645+
        cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod)
646+
647+
    @classmethod
648+
    def _add_comparison_ops(cls):
649+
        cls.__eq__ = cls._create_comparison_method(operator.eq)
650+
        cls.__ne__ = cls._create_comparison_method(operator.ne)
651+
        cls.__lt__ = cls._create_comparison_method(operator.lt)
652+
        cls.__gt__ = cls._create_comparison_method(operator.gt)
653+
        cls.__le__ = cls._create_comparison_method(operator.le)
654+
        cls.__ge__ = cls._create_comparison_method(operator.ge)
655+
656+
657+
class ExtensionScalarOpsMixin(ExtensionOpsMixin):
658+
    """A mixin for defining the arithmetic and logical operations on
659+
    an ExtensionArray class, where it is assumed that the underlying objects
660+
    have the operators already defined.
661+
662+
    Usage
663+
    ------
664+
    If you have defined a subclass MyExtensionArray(ExtensionArray), then
665+
    use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to
666+
    get the arithmetic operators. After the definition of MyExtensionArray,
667+
    insert the lines
668+
669+
    MyExtensionArray._add_arithmetic_ops()
670+
    MyExtensionArray._add_comparison_ops()
671+
672+
    to link the operators to your class.
673+
    """
674+
675+
    @classmethod
676+
    def _create_method(cls, op, coerce_to_dtype=True):
677+
        """
678+
        A class method that returns a method that will correspond to an
679+
        operator for an ExtensionArray subclass, by dispatching to the
680+
        relevant operator defined on the individual elements of the
681+
        ExtensionArray.
682+
683+
        Parameters
684+
        ----------
685+
        op : function
686+
            An operator that takes arguments op(a, b)
687+
        coerce_to_dtype : bool
688+
            boolean indicating whether to attempt to convert
689+
            the result to the underlying ExtensionArray dtype
690+
            (default True)
691+
692+
        Returns
693+
        -------
694+
        A method that can be bound to a method of a class
695+
696+
        Example
697+
        -------
698+
        Given an ExtensionArray subclass called MyExtensionArray, use
699+
700+
        >>> __add__ = cls._create_method(operator.add)
701+
702+
        in the class definition of MyExtensionArray to create the operator
703+
        for addition, that will be based on the operator implementation
704+
        of the underlying elements of the ExtensionArray
705+
706+
        """
707+
708+
        def _binop(self, other):
709+
            def convert_values(param):
710+
                if isinstance(param, ExtensionArray) or is_list_like(param):
711+
                    ovalues = param
712+
                else:  # Assume it's an object
713+
                    ovalues = [param] * len(self)
714+
                return ovalues
715+
            lvalues = self
716+
            rvalues = convert_values(other)
717+
718+
            # If the operator is not defined for the underlying objects,
719+
            # a TypeError should be raised
720+
            res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]
721+
722+
            if coerce_to_dtype:
723+
                try:
724+
                    res = self._from_sequence(res)
725+
                except TypeError:
726+
                    pass
727+
728+
            return res
729+
730+
        op_name = ops._get_op_name(op, True)
731+
        return set_function_name(_binop, op_name, cls)
732+
733+
    @classmethod
734+
    def _create_arithmetic_method(cls, op):
735+
        return cls._create_method(op)
736+
737+
    @classmethod
738+
    def _create_comparison_method(cls, op):
739+
        return cls._create_method(op, coerce_to_dtype=False)
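
The mechanics of ``_create_method`` can be tried out without pandas. The sketch below is a simplified stand-in, not the pandas implementation: ``ToyArray`` and ``create_method`` are hypothetical names, the elements are assumed to be ``decimal.Decimal`` values, and the list-like check is reduced to a plain ``isinstance`` test. It shows the same three steps: broadcast a scalar operand, apply the operator element by element, and try to coerce the result back to the array type, keeping the plain list if coercion raises ``TypeError``.

```python
import operator
from decimal import Decimal


class ToyArray:
    """Hypothetical stand-in for an ExtensionArray; not a pandas class."""

    def __init__(self, values):
        self._values = list(values)

    @classmethod
    def _from_sequence(cls, scalars):
        # Reject anything that is not a Decimal so that coercion can fail.
        scalars = list(scalars)
        if not all(isinstance(s, Decimal) for s in scalars):
            raise TypeError("ToyArray only holds Decimal values")
        return cls(scalars)

    def __len__(self):
        return len(self._values)

    def __iter__(self):
        return iter(self._values)


def create_method(op, coerce_to_dtype=True):
    """Build a binary method that applies ``op`` element by element."""

    def _binop(self, other):
        # Broadcast a scalar operand to the length of ``self``.
        if isinstance(other, (ToyArray, list, tuple)):
            rvalues = other
        else:
            rvalues = [other] * len(self)

        res = [op(a, b) for a, b in zip(self, rvalues)]

        if coerce_to_dtype:
            try:
                res = self._from_sequence(res)
            except TypeError:
                pass  # keep the plain list of results
        return res

    return _binop


# Hook the generated methods onto the class after its definition.
ToyArray.__add__ = create_method(operator.add)
ToyArray.__lt__ = create_method(operator.lt, coerce_to_dtype=False)

arr = ToyArray([Decimal("1.5"), Decimal("2.5")])
total = arr + Decimal("1")   # coerced back to a ToyArray
mask = arr < Decimal("2")    # comparison result stays a plain list of bools
```

Because every element goes through a Python-level call, this pattern trades speed for convenience, which matches the performance note in the documentation change in this commit.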

pandas/core/ops.py

+31
@@ -33,6 +33,7 @@
3333
is_bool_dtype,
3434
is_list_like,
3535
is_scalar,
36+
is_extension_array_dtype,
3637
_ensure_object)
3738
from pandas.core.dtypes.cast import (
3839
maybe_upcast_putmask, find_common_type,
@@ -993,6 +994,26 @@ def _construct_divmod_result(left, result, index, name, dtype):
993994
)
994995

995996

997+
def dispatch_to_extension_op(op, left, right):
998+
    """
999+
    Assume that left or right is a Series backed by an ExtensionArray,
1000+
    apply the operator defined by op.
1001+
    """
1002+
1003+
    # The op calls will raise TypeError if the op is not defined
1004+
    # on the ExtensionArray
1005+
    if is_extension_array_dtype(left):
1006+
        res_values = op(left.values, right)
1007+
    else:
1008+
        # We know that left is not ExtensionArray and is Series and right is
1009+
        # ExtensionArray. Want to force ExtensionArray op to get called
1010+
        res_values = op(list(left.values), right.values)
1011+
1012+
    res_name = get_op_result_name(left, right)
1013+
    return left._constructor(res_values, index=left.index,
1014+
                             name=res_name)
1015+
1016+
9961017
def _arith_method_SERIES(cls, op, special):
9971018
"""
9981019
Wrapper function for Series arithmetic operations, to avoid
@@ -1061,6 +1082,11 @@ def wrapper(left, right):
10611082
raise TypeError("{typ} cannot perform the operation "
10621083
"{op}".format(typ=type(left).__name__, op=str_rep))
10631084

1085+
elif (is_extension_array_dtype(left) or
1086+
(is_extension_array_dtype(right) and
1087+
not is_categorical_dtype(right))):
1088+
return dispatch_to_extension_op(op, left, right)
1089+
10641090
lvalues = left.values
10651091
rvalues = right
10661092
if isinstance(rvalues, ABCSeries):
@@ -1238,6 +1264,11 @@ def wrapper(self, other, axis=None):
12381264
return self._constructor(res_values, index=self.index,
12391265
name=res_name)
12401266

1267+
elif (is_extension_array_dtype(self) or
1268+
(is_extension_array_dtype(other) and
1269+
not is_categorical_dtype(other))):
1270+
return dispatch_to_extension_op(op, self, other)
1271+
12411272
elif isinstance(other, ABCSeries):
12421273
# By this point we have checked that self._indexed_same(other)
12431274
res_values = na_op(self.values, other.values)
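
In ``dispatch_to_extension_op``, the branch for a non-extension left operand passes ``list(left.values)`` so that the extension array's own operator gets called. The sketch below illustrates the Python mechanism this appears to rely on: a plain ``list`` does not know how to add an unrelated object, so Python falls back to the right operand's reflected method. ``ToyArray`` is a hypothetical stand-in, not a pandas class.

```python
class ToyArray:
    """Hypothetical array-like with element-wise addition; not a pandas class."""

    def __init__(self, values):
        self.values = list(values)

    def __iter__(self):
        return iter(self.values)

    def __add__(self, other):
        # self + other, element by element
        return ToyArray(a + b for a, b in zip(self.values, other))

    def __radd__(self, other):
        # other + self: reached when the left operand cannot handle the addition.
        return ToyArray(b + a for a, b in zip(self.values, other))


left = [1, 2, 3]                  # plain list, analogous to list(left.values)
right = ToyArray([10, 20, 30])

# list has no way to add a ToyArray, so Python calls ToyArray.__radd__.
result = left + right
```

The result is a ``ToyArray`` rather than a list, which is the behavior the comment in the diff describes as forcing the ExtensionArray op to be called.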

pandas/tests/extension/base/__init__.py

+1
@@ -47,6 +47,7 @@ class TestMyDtype(BaseDtypeTests):
4747
from .groupby import BaseGroupbyTests # noqa
4848
from .interface import BaseInterfaceTests # noqa
4949
from .methods import BaseMethodsTests # noqa
50+
from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests # noqa
5051
from .missing import BaseMissingTests # noqa
5152
from .reshaping import BaseReshapingTests # noqa
5253
from .setitem import BaseSetitemTests # noqa
