Skip to content

Commit b7107a9

Browse files
jbrockmendelSeeminSyed
authored andcommitted
REF: simplify should_extension_dispatch, remove dispatch_to_extension_op (pandas-dev#32892)
1 parent 8614993 commit b7107a9

File tree

3 files changed

+106
-141
lines changed

3 files changed

+106
-141
lines changed

pandas/core/ops/__init__.py

+3-63
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@
33
44
This is not a public API.
55
"""
6-
import datetime
76
import operator
8-
from typing import TYPE_CHECKING, Optional, Set, Tuple
7+
from typing import TYPE_CHECKING, Optional, Set
98

109
import numpy as np
1110

12-
from pandas._libs import Timedelta, Timestamp, lib
11+
from pandas._libs import lib
1312
from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
1413
from pandas._typing import ArrayLike, Level
1514
from pandas.util._decorators import Appender
1615

17-
from pandas.core.dtypes.common import is_list_like, is_timedelta64_dtype
16+
from pandas.core.dtypes.common import is_list_like
1817
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
1918
from pandas.core.dtypes.missing import isna
2019

@@ -152,65 +151,6 @@ def _maybe_match_name(a, b):
152151
return None
153152

154153

155-
def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
156-
"""
157-
Cast non-pandas objects to pandas types to unify behavior of arithmetic
158-
and comparison operations.
159-
160-
Parameters
161-
----------
162-
obj: object
163-
shape : tuple[int]
164-
165-
Returns
166-
-------
167-
out : object
168-
169-
Notes
170-
-----
171-
Be careful to call this *after* determining the `name` attribute to be
172-
attached to the result of the arithmetic operation.
173-
"""
174-
from pandas.core.arrays import DatetimeArray, TimedeltaArray
175-
176-
if type(obj) is datetime.timedelta:
177-
# GH#22390 cast up to Timedelta to rely on Timedelta
178-
# implementation; otherwise operation against numeric-dtype
179-
# raises TypeError
180-
return Timedelta(obj)
181-
elif isinstance(obj, np.datetime64):
182-
# GH#28080 numpy casts integer-dtype to datetime64 when doing
183-
# array[int] + datetime64, which we do not allow
184-
if isna(obj):
185-
# Avoid possible ambiguities with pd.NaT
186-
obj = obj.astype("datetime64[ns]")
187-
right = np.broadcast_to(obj, shape)
188-
return DatetimeArray(right)
189-
190-
return Timestamp(obj)
191-
192-
elif isinstance(obj, np.timedelta64):
193-
if isna(obj):
194-
# wrapping timedelta64("NaT") in Timedelta returns NaT,
195-
# which would incorrectly be treated as a datetime-NaT, so
196-
# we broadcast and wrap in a TimedeltaArray
197-
obj = obj.astype("timedelta64[ns]")
198-
right = np.broadcast_to(obj, shape)
199-
return TimedeltaArray(right)
200-
201-
# In particular non-nanosecond timedelta64 needs to be cast to
202-
# nanoseconds, or else we get undesired behavior like
203-
# np.timedelta64(3, 'D') / 2 == np.timedelta64(1, 'D')
204-
return Timedelta(obj)
205-
206-
elif isinstance(obj, np.ndarray) and is_timedelta64_dtype(obj.dtype):
207-
# GH#22390 Unfortunately we need to special-case right-hand
208-
# timedelta64 dtypes because numpy casts integer dtypes to
209-
# timedelta64 when operating with timedelta64
210-
return TimedeltaArray._from_sequence(obj)
211-
return obj
212-
213-
214154
# -----------------------------------------------------------------------------
215155

216156

pandas/core/ops/array_ops.py

+98-25
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
Functions for arithmetic and comparison operations on NumPy arrays and
33
ExtensionArrays.
44
"""
5+
from datetime import timedelta
56
from functools import partial
67
import operator
7-
from typing import Any, Optional
8+
from typing import Any, Optional, Tuple
89

910
import numpy as np
1011

@@ -24,17 +25,11 @@
2425
is_object_dtype,
2526
is_scalar,
2627
)
27-
from pandas.core.dtypes.generic import (
28-
ABCDatetimeArray,
29-
ABCExtensionArray,
30-
ABCIndex,
31-
ABCSeries,
32-
ABCTimedeltaArray,
33-
)
28+
from pandas.core.dtypes.generic import ABCExtensionArray, ABCIndex, ABCSeries
3429
from pandas.core.dtypes.missing import isna, notna
3530

3631
from pandas.core.ops import missing
37-
from pandas.core.ops.dispatch import dispatch_to_extension_op, should_extension_dispatch
32+
from pandas.core.ops.dispatch import should_extension_dispatch
3833
from pandas.core.ops.invalid import invalid_comparison
3934
from pandas.core.ops.roperator import rpow
4035

@@ -199,23 +194,15 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
199194
ndarrray or ExtensionArray
200195
Or a 2-tuple of these in the case of divmod or rdivmod.
201196
"""
202-
from pandas.core.ops import maybe_upcast_for_op
203197

204198
# NB: We assume that extract_array has already been called
205199
# on `left` and `right`.
206-
lvalues = left
207-
rvalues = right
200+
lvalues = maybe_upcast_datetimelike_array(left)
201+
rvalues = maybe_upcast_for_op(right, lvalues.shape)
208202

209-
rvalues = maybe_upcast_for_op(rvalues, lvalues.shape)
210-
211-
if should_extension_dispatch(left, rvalues) or isinstance(
212-
rvalues, (ABCTimedeltaArray, ABCDatetimeArray, Timestamp, Timedelta)
213-
):
214-
# TimedeltaArray, DatetimeArray, and Timestamp are included here
215-
# because they have `freq` attribute which is handled correctly
216-
# by dispatch_to_extension_op.
203+
if should_extension_dispatch(lvalues, rvalues) or isinstance(rvalues, Timedelta):
217204
# Timedelta is included because numexpr will fail on it, see GH#31457
218-
res_values = dispatch_to_extension_op(op, lvalues, rvalues)
205+
res_values = op(lvalues, rvalues)
219206

220207
else:
221208
with np.errstate(all="ignore"):
@@ -287,7 +274,7 @@ def comparison_op(
287274
ndarray or ExtensionArray
288275
"""
289276
# NB: We assume extract_array has already been called on left and right
290-
lvalues = left
277+
lvalues = maybe_upcast_datetimelike_array(left)
291278
rvalues = right
292279

293280
rvalues = lib.item_from_zerodim(rvalues)
@@ -307,7 +294,8 @@ def comparison_op(
307294
)
308295

309296
if should_extension_dispatch(lvalues, rvalues):
310-
res_values = dispatch_to_extension_op(op, lvalues, rvalues)
297+
# Call the method on lvalues
298+
res_values = op(lvalues, rvalues)
311299

312300
elif is_scalar(rvalues) and isna(rvalues):
313301
# numpy does not like comparisons vs None
@@ -406,11 +394,12 @@ def fill_bool(x, left=None):
406394
right = construct_1d_object_array_from_listlike(right)
407395

408396
# NB: We assume extract_array has already been called on left and right
409-
lvalues = left
397+
lvalues = maybe_upcast_datetimelike_array(left)
410398
rvalues = right
411399

412400
if should_extension_dispatch(lvalues, rvalues):
413-
res_values = dispatch_to_extension_op(op, lvalues, rvalues)
401+
# Call the method on lvalues
402+
res_values = op(lvalues, rvalues)
414403

415404
else:
416405
if isinstance(rvalues, np.ndarray):
@@ -453,3 +442,87 @@ def get_array_op(op, str_rep: Optional[str] = None):
453442
return partial(logical_op, op=op)
454443
else:
455444
return partial(arithmetic_op, op=op, str_rep=str_rep)
445+
446+
447+
def maybe_upcast_datetimelike_array(obj: ArrayLike) -> ArrayLike:
448+
"""
449+
If we have an ndarray that is either datetime64 or timedelta64, wrap in EA.
450+
451+
Parameters
452+
----------
453+
obj : ndarray or ExtensionArray
454+
455+
Returns
456+
-------
457+
ndarray or ExtensionArray
458+
"""
459+
if isinstance(obj, np.ndarray):
460+
if obj.dtype.kind == "m":
461+
from pandas.core.arrays import TimedeltaArray
462+
463+
return TimedeltaArray._from_sequence(obj)
464+
if obj.dtype.kind == "M":
465+
from pandas.core.arrays import DatetimeArray
466+
467+
return DatetimeArray._from_sequence(obj)
468+
469+
return obj
470+
471+
472+
def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
473+
"""
474+
Cast non-pandas objects to pandas types to unify behavior of arithmetic
475+
and comparison operations.
476+
477+
Parameters
478+
----------
479+
obj: object
480+
shape : tuple[int]
481+
482+
Returns
483+
-------
484+
out : object
485+
486+
Notes
487+
-----
488+
Be careful to call this *after* determining the `name` attribute to be
489+
attached to the result of the arithmetic operation.
490+
"""
491+
from pandas.core.arrays import DatetimeArray, TimedeltaArray
492+
493+
if type(obj) is timedelta:
494+
# GH#22390 cast up to Timedelta to rely on Timedelta
495+
# implementation; otherwise operation against numeric-dtype
496+
# raises TypeError
497+
return Timedelta(obj)
498+
elif isinstance(obj, np.datetime64):
499+
# GH#28080 numpy casts integer-dtype to datetime64 when doing
500+
# array[int] + datetime64, which we do not allow
501+
if isna(obj):
502+
# Avoid possible ambiguities with pd.NaT
503+
obj = obj.astype("datetime64[ns]")
504+
right = np.broadcast_to(obj, shape)
505+
return DatetimeArray(right)
506+
507+
return Timestamp(obj)
508+
509+
elif isinstance(obj, np.timedelta64):
510+
if isna(obj):
511+
# wrapping timedelta64("NaT") in Timedelta returns NaT,
512+
# which would incorrectly be treated as a datetime-NaT, so
513+
# we broadcast and wrap in a TimedeltaArray
514+
obj = obj.astype("timedelta64[ns]")
515+
right = np.broadcast_to(obj, shape)
516+
return TimedeltaArray(right)
517+
518+
# In particular non-nanosecond timedelta64 needs to be cast to
519+
# nanoseconds, or else we get undesired behavior like
520+
# np.timedelta64(3, 'D') / 2 == np.timedelta64(1, 'D')
521+
return Timedelta(obj)
522+
523+
elif isinstance(obj, np.ndarray) and obj.dtype.kind == "m":
524+
# GH#22390 Unfortunately we need to special-case right-hand
525+
# timedelta64 dtypes because numpy casts integer dtypes to
526+
# timedelta64 when operating with timedelta64
527+
return TimedeltaArray._from_sequence(obj)
528+
return obj

pandas/core/ops/dispatch.py

+5-53
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,31 @@
33
"""
44
from typing import Any
55

6-
import numpy as np
7-
86
from pandas._typing import ArrayLike
97

108
from pandas.core.dtypes.common import (
119
is_datetime64_dtype,
12-
is_extension_array_dtype,
1310
is_integer_dtype,
1411
is_object_dtype,
15-
is_scalar,
1612
is_timedelta64_dtype,
1713
)
18-
from pandas.core.dtypes.generic import ABCSeries
19-
20-
from pandas.core.construction import array
14+
from pandas.core.dtypes.generic import ABCExtensionArray
2115

2216

23-
def should_extension_dispatch(left: ABCSeries, right: Any) -> bool:
17+
def should_extension_dispatch(left: ArrayLike, right: Any) -> bool:
2418
"""
25-
Identify cases where Series operation should use dispatch_to_extension_op.
19+
Identify cases where Series operation should dispatch to ExtensionArray method.
2620
2721
Parameters
2822
----------
29-
left : Series
23+
left : np.ndarray or ExtensionArray
3024
right : object
3125
3226
Returns
3327
-------
3428
bool
3529
"""
36-
if (
37-
is_extension_array_dtype(left.dtype)
38-
or is_datetime64_dtype(left.dtype)
39-
or is_timedelta64_dtype(left.dtype)
40-
):
41-
return True
42-
43-
if not is_scalar(right) and is_extension_array_dtype(right):
44-
# GH#22378 disallow scalar to exclude e.g. "category", "Int64"
45-
return True
46-
47-
return False
30+
return isinstance(left, ABCExtensionArray) or isinstance(right, ABCExtensionArray)
4831

4932

5033
def should_series_dispatch(left, right, op):
@@ -93,34 +76,3 @@ def should_series_dispatch(left, right, op):
9376
return True
9477

9578
return False
96-
97-
98-
def dispatch_to_extension_op(op, left: ArrayLike, right: Any):
99-
"""
100-
Assume that left or right is a Series backed by an ExtensionArray,
101-
apply the operator defined by op.
102-
103-
Parameters
104-
----------
105-
op : binary operator
106-
left : ExtensionArray or np.ndarray
107-
right : object
108-
109-
Returns
110-
-------
111-
ExtensionArray or np.ndarray
112-
2-tuple of these if op is divmod or rdivmod
113-
"""
114-
# NB: left and right should already be unboxed, so neither should be
115-
# a Series or Index.
116-
117-
if left.dtype.kind in "mM" and isinstance(left, np.ndarray):
118-
# We need to cast datetime64 and timedelta64 ndarrays to
119-
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
120-
# PandasArray as that behaves poorly with e.g. IntegerArray.
121-
left = array(left)
122-
123-
# The op calls will raise TypeError if the op is not defined
124-
# on the ExtensionArray
125-
res_values = op(left, right)
126-
return res_values

0 commit comments

Comments
 (0)