Commit 06d3f85

TomAugspurger authored and tm9k1 committed
BUG/PERF: Avoid listifying in dispatch_to_extension_op (pandas-dev#23155)
This simplifies dispatch_to_extension_op. The remaining logic is simply unboxing Series / Indexes in favor of their underlying arrays. This forced two additional changes:

1. Move some logic that IntegerArray relied on down to the IntegerArray ops. Things like handling of 0-dim ndarrays were previously broken on IntegerArray ops, but worked with Series[IntegerArray].
2. Fix pandas' handling of 1 ** NA.
1 parent 09b1c6c commit 06d3f85

12 files changed, +196 -45 lines changed

doc/source/extending.rst
+16

@@ -135,6 +135,12 @@ There are two approaches for providing operator support for your ExtensionArray:
     2. Use an operator implementation from pandas that depends on operators that are already defined
        on the underlying elements (scalars) of the ExtensionArray.

+.. note::
+
+   Regardless of the approach, you may want to set ``__array_priority__``
+   if you want your implementation to be called when involved in binary operations
+   with NumPy arrays.
+
 For the first approach, you define selected operators, e.g., ``__add__``, ``__le__``, etc. that
 you want your ``ExtensionArray`` subclass to support.

@@ -173,6 +179,16 @@ or not that succeeds depends on whether the operation returns a result
 that's valid for the ``ExtensionArray``. If an ``ExtensionArray`` cannot
 be reconstructed, an ndarray containing the scalars returned instead.

+For ease of implementation and consistency with operations between pandas
+and NumPy ndarrays, we recommend *not* handling Series and Indexes in your binary ops.
+Instead, you should detect these cases and return ``NotImplemented``.
+When pandas encounters an operation like ``op(Series, ExtensionArray)``, pandas
+will
+
+1. unbox the array from the ``Series`` (roughly ``Series.values``)
+2. call ``result = op(values, ExtensionArray)``
+3. re-box the result in a ``Series``
+
 .. _extending.extension.testing:

 Testing Extension Arrays
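
The unbox / call / re-box steps documented above can be sketched without pandas. This is only a toy model: `ToyArray` and `ToySeries` are hypothetical stand-ins, not pandas classes, and the real dispatch lives in `dispatch_to_extension_op`. The sketch shows why returning `NotImplemented` from the array is enough: Python then tries the reflected method on the container, which unboxes and calls back into the array.

```python
import operator


class ToyArray:
    """Hypothetical stand-in for an ExtensionArray (not a pandas class)."""

    def __init__(self, data):
        self.data = list(data)

    def _binop(self, other, op):
        if isinstance(other, ToySeries):
            # Do not handle the container here; let it unbox and call us again.
            return NotImplemented
        if isinstance(other, ToyArray):
            other = other.data
        else:
            # Treat anything else as a scalar and broadcast it.
            other = [other] * len(self.data)
        return ToyArray(op(a, b) for a, b in zip(self.data, other))

    def __add__(self, other):
        return self._binop(other, operator.add)

    def __radd__(self, other):
        return self._binop(other, lambda a, b: b + a)


class ToySeries:
    """Hypothetical stand-in for a Series: a named box around an array."""

    def __init__(self, values, name=None):
        self.values = values
        self.name = name

    @staticmethod
    def _unbox(obj):
        # Step 1: pull the underlying array out of the container.
        return obj.values if isinstance(obj, ToySeries) else obj

    def __add__(self, other):
        result = self._unbox(self) + self._unbox(other)  # step 2
        return ToySeries(result, name=self.name)         # step 3

    def __radd__(self, other):
        result = self._unbox(other) + self._unbox(self)  # step 2
        return ToySeries(result, name=self.name)         # step 3


arr = ToyArray([1, 2, 3])
ser = ToySeries(ToyArray([10, 20, 30]), name="x")

left = ser + arr   # ToySeries.__add__ unboxes, adds, and re-boxes
right = arr + ser  # ToyArray.__add__ returns NotImplemented, so Python
                   # falls back to ToySeries.__radd__
```

Both orderings end up in the same array-level code path, which is the consistency the added documentation asks extension authors to rely on.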

doc/source/whatsnew/v0.24.0.txt
+1

@@ -879,6 +879,7 @@ Numeric
 - Bug in :meth:`DataFrame.apply` where, when supplied with a string argument and additional positional or keyword arguments (e.g. ``df.apply('sum', min_count=1)``), a ``TypeError`` was wrongly raised (:issue:`22376`)
 - Bug in :meth:`DataFrame.astype` to extension dtype may raise ``AttributeError`` (:issue:`22578`)
 - Bug in :class:`DataFrame` with ``timedelta64[ns]`` dtype arithmetic operations with ``ndarray`` with integer dtype incorrectly treating the narray as ``timedelta64[ns]`` dtype (:issue:`23114`)
+- Bug in :meth:`Series.rpow` with object dtype ``NaN`` for ``1 ** NA`` instead of ``1`` (:issue:`22922`).

 Strings
 ^^^^^^^

pandas/core/arrays/base.py
+27 -6

@@ -9,6 +9,7 @@

 import operator

+from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
 from pandas.errors import AbstractMethodError
 from pandas.compat.numpy import function as nv
 from pandas.compat import set_function_name, PY3
@@ -109,6 +110,7 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
             compatible with the ExtensionArray.
         copy : boolean, default False
             If True, copy the underlying data.
+
         Returns
         -------
         ExtensionArray
@@ -724,7 +726,13 @@ def _reduce(self, name, skipna=True, **kwargs):

 class ExtensionOpsMixin(object):
     """
-    A base class for linking the operators to their dunder names
+    A base class for linking the operators to their dunder names.
+
+    .. note::
+
+       You may want to set ``__array_priority__`` if you want your
+       implementation to be called when involved in binary operations
+       with NumPy arrays.
     """

     @classmethod
@@ -761,12 +769,14 @@ def _add_comparison_ops(cls):


 class ExtensionScalarOpsMixin(ExtensionOpsMixin):
-    """A mixin for defining the arithmetic and logical operations on
-    an ExtensionArray class, where it is assumed that the underlying objects
-    have the operators already defined.
+    """
+    A mixin for defining ops on an ExtensionArray.
+
+    It is assumed that the underlying scalar objects have the operators
+    already defined.

-    Usage
-    ------
+    Notes
+    -----
     If you have defined a subclass MyExtensionArray(ExtensionArray), then
     use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to
     get the arithmetic operators. After the definition of MyExtensionArray,
@@ -776,6 +786,12 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin):
     MyExtensionArray._add_comparison_ops()

     to link the operators to your class.
+
+    .. note::
+
+       You may want to set ``__array_priority__`` if you want your
+       implementation to be called when involved in binary operations
+       with NumPy arrays.
     """

     @classmethod
@@ -825,6 +841,11 @@ def convert_values(param):
                 else:  # Assume its an object
                     ovalues = [param] * len(self)
                 return ovalues
+
+            if isinstance(other, (ABCSeries, ABCIndexClass)):
+                # rely on pandas to unbox and dispatch to us
+                return NotImplemented
+
             lvalues = self
             rvalues = convert_values(other)

pandas/core/arrays/integer.py
+31 -12

@@ -3,7 +3,8 @@
 import copy
 import numpy as np

-from pandas._libs.lib import infer_dtype
+
+from pandas._libs import lib
 from pandas.util._decorators import cache_readonly
 from pandas.compat import u, range, string_types
 from pandas.compat import set_function_name
@@ -171,7 +172,7 @@ def coerce_to_array(values, dtype, mask=None, copy=False):

     values = np.array(values, copy=copy)
     if is_object_dtype(values):
-        inferred_type = infer_dtype(values)
+        inferred_type = lib.infer_dtype(values)
         if inferred_type not in ['floating', 'integer',
                                  'mixed-integer', 'mixed-integer-float']:
             raise TypeError("{} cannot be converted to an IntegerDtype".format(
@@ -280,6 +281,8 @@ def _coerce_to_ndarray(self):
         data[self._mask] = self._na_value
         return data

+    __array_priority__ = 1000  # higher than ndarray so ops dispatch to us
+
     def __array__(self, dtype=None):
         """
         the array interface, return my values
@@ -288,12 +291,6 @@ def __array__(self, dtype=None):
         return self._coerce_to_ndarray()

     def __iter__(self):
-        """Iterate over elements of the array.
-
-        """
-        # This needs to be implemented so that pandas recognizes extension
-        # arrays as list-like. The default implementation makes successive
-        # calls to ``__getitem__``, which may be slower than necessary.
         for i in range(len(self)):
             if self._mask[i]:
                 yield self.dtype.na_value
@@ -504,13 +501,21 @@ def cmp_method(self, other):

             op_name = op.__name__
             mask = None
+
+            if isinstance(other, (ABCSeries, ABCIndexClass)):
+                # Rely on pandas to unbox and dispatch to us.
+                return NotImplemented
+
             if isinstance(other, IntegerArray):
                 other, mask = other._data, other._mask
+
             elif is_list_like(other):
                 other = np.asarray(other)
                 if other.ndim > 0 and len(self) != len(other):
                     raise ValueError('Lengths must match to compare')

+            other = lib.item_from_zerodim(other)
+
             # numpy will show a DeprecationWarning on invalid elementwise
             # comparisons, this will raise in the future
             with warnings.catch_warnings():
@@ -586,14 +591,21 @@ def integer_arithmetic_method(self, other):

             op_name = op.__name__
             mask = None
+
             if isinstance(other, (ABCSeries, ABCIndexClass)):
-                other = getattr(other, 'values', other)
+                # Rely on pandas to unbox and dispatch to us.
+                return NotImplemented

-            if isinstance(other, IntegerArray):
-                other, mask = other._data, other._mask
-            elif getattr(other, 'ndim', 0) > 1:
+            if getattr(other, 'ndim', 0) > 1:
                 raise NotImplementedError(
                     "can only perform ops with 1-d structures")
+
+            if isinstance(other, IntegerArray):
+                other, mask = other._data, other._mask
+
+            elif getattr(other, 'ndim', None) == 0:
+                other = other.item()
+
             elif is_list_like(other):
                 other = np.asarray(other)
                 if not other.ndim:
@@ -612,6 +624,13 @@ def integer_arithmetic_method(self, other):
             else:
                 mask = self._mask | mask

+            # 1 ** np.nan is 1. So we have to unmask those.
+            if op_name == 'pow':
+                mask = np.where(self == 1, False, mask)
+
+            elif op_name == 'rpow':
+                mask = np.where(other == 1, False, mask)
+
             with np.errstate(all='ignore'):
                 result = op(self._data, other)
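
The `pow` branch in the last hunk can be illustrated with plain lists. This is a rough sketch of the masking rule only, not the IntegerArray implementation: `masked_pow` and the `NA` marker are hypothetical names, and masks here are lists of booleans where `True` means "missing".

```python
NA = None  # hypothetical missing-value marker used only in this sketch


def masked_pow(data, mask, exponent, exponent_mask):
    """Elementwise ``data ** exponent`` with explicit missing-value masks."""
    # Default rule: the result is missing if either operand is missing.
    out_mask = [m or em for m, em in zip(mask, exponent_mask)]

    # Special case: 1 ** anything is 1, so a valid base equal to 1 gives a
    # valid result even when the exponent is missing.
    out_mask = [
        False if (not m and base == 1) else om
        for base, m, om in zip(data, mask, out_mask)
    ]

    result = []
    for base, exp, missing in zip(data, exponent, out_mask):
        if missing:
            result.append(NA)
        elif base == 1:
            result.append(1)  # the exponent may be a placeholder here
        else:
            result.append(base ** exp)
    return result, out_mask


values, value_mask = masked_pow(
    data=[1, 2, 3, 1],
    mask=[False, False, True, False],
    exponent=[0, 0, 2, 5],
    exponent_mask=[True, True, False, False],
)
# values     -> [1, None, None, 1]
# value_mask -> [False, True, True, False]
```

Only the first position changes relative to the default rule: its exponent is missing, but the base is a valid 1, so the result stays valid.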

pandas/core/arrays/sparse.py
+23 -4

@@ -1459,15 +1459,32 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
             'power': 'pow',
             'remainder': 'mod',
             'divide': 'div',
+            'equal': 'eq',
+            'not_equal': 'ne',
+            'less': 'lt',
+            'less_equal': 'le',
+            'greater': 'gt',
+            'greater_equal': 'ge',
         }
+
+        flipped = {
+            'lt': '__gt__',
+            'le': '__ge__',
+            'gt': '__lt__',
+            'ge': '__le__',
+            'eq': '__eq__',
+            'ne': '__ne__',
+        }
+
         op_name = ufunc.__name__
         op_name = aliases.get(op_name, op_name)

         if op_name in special and kwargs.get('out') is None:
             if isinstance(inputs[0], type(self)):
                 return getattr(self, '__{}__'.format(op_name))(inputs[1])
             else:
-                return getattr(self, '__r{}__'.format(op_name))(inputs[0])
+                name = flipped.get(op_name, '__r{}__'.format(op_name))
+                return getattr(self, name)(inputs[0])

         if len(inputs) == 1:
             # No alignment necessary.
@@ -1516,7 +1533,8 @@ def sparse_arithmetic_method(self, other):
             op_name = op.__name__

             if isinstance(other, (ABCSeries, ABCIndexClass)):
-                other = getattr(other, 'values', other)
+                # Rely on pandas to dispatch to us.
+                return NotImplemented

             if isinstance(other, SparseArray):
                 return _sparse_array_op(self, other, op, op_name)
@@ -1561,10 +1579,11 @@ def cmp_method(self, other):
                 op_name = op_name[:-1]

             if isinstance(other, (ABCSeries, ABCIndexClass)):
-                other = getattr(other, 'values', other)
+                # Rely on pandas to unbox and dispatch to us.
+                return NotImplemented

             if not is_scalar(other) and not isinstance(other, type(self)):
-                # convert list-like to ndarary
+                # convert list-like to ndarray
                 other = np.asarray(other)

             if isinstance(other, np.ndarray):
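
The `flipped` table exists because comparisons have no `__rlt__`-style reflected methods: when the array is the right-hand operand, `other < array` has to be answered as `array > other`. A minimal sketch of that idea, using a hypothetical `compare_reflected` helper over a plain list rather than a SparseArray:

```python
import operator

# Maps a comparison to the one that gives the same answer with the
# operands swapped. Equality and inequality are symmetric.
FLIPPED = {
    "lt": operator.gt,
    "le": operator.ge,
    "gt": operator.lt,
    "ge": operator.le,
    "eq": operator.eq,
    "ne": operator.ne,
}


def compare_reflected(op_name, array_values, other):
    """Evaluate ``other <op> array`` elementwise from the array's side."""
    flipped = FLIPPED[op_name]
    return [flipped(value, other) for value in array_values]


values = [1, 5, 9]
# "5 < values" asks whether 5 is less than each element,
# which is the same as asking whether each element is greater than 5.
result = compare_reflected("lt", values, 5)
# result -> [False, False, True]
```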

pandas/core/ops.py
+14 -20

@@ -862,6 +862,13 @@ def masked_arith_op(x, y, op):
         # mask is only meaningful for x
         result = np.empty(x.size, dtype=x.dtype)
         mask = notna(xrav)
+
+        # 1 ** np.nan is 1. So we have to unmask those.
+        if op == pow:
+            mask = np.where(x == 1, False, mask)
+        elif op == rpow:
+            mask = np.where(y == 1, False, mask)
+
         if mask.any():
             with np.errstate(all='ignore'):
                 result[mask] = op(xrav[mask], y)
@@ -1202,29 +1209,16 @@ def dispatch_to_extension_op(op, left, right):

     # The op calls will raise TypeError if the op is not defined
     # on the ExtensionArray
-    # TODO(jreback)
-    # we need to listify to avoid ndarray, or non-same-type extension array
-    # dispatching
-
-    if is_extension_array_dtype(left):
-
-        new_left = left.values
-        if isinstance(right, np.ndarray):
-
-            # handle numpy scalars, this is a PITA
-            # TODO(jreback)
-            new_right = lib.item_from_zerodim(right)
-            if is_scalar(new_right):
-                new_right = [new_right]
-            new_right = list(new_right)
-        elif is_extension_array_dtype(right) and type(left) != type(right):
-            new_right = list(right)
-        else:
-            new_right = right

+    # unbox Series and Index to arrays
+    if isinstance(left, (ABCSeries, ABCIndexClass)):
+        new_left = left._values
     else:
+        new_left = left

-        new_left = list(left.values)
+    if isinstance(right, (ABCSeries, ABCIndexClass)):
+        new_right = right._values
+    else:
         new_right = right

     res_values = op(new_left, new_right)
