5
5
"""
6
6
import datetime
7
7
import operator
8
- from typing import Any , Callable , Tuple
8
+ from typing import Any , Callable , Tuple , Union
9
9
10
10
import numpy as np
11
11
34
34
ABCIndexClass ,
35
35
ABCSeries ,
36
36
ABCSparseSeries ,
37
+ ABCTimedeltaArray ,
38
+ ABCTimedeltaIndex ,
37
39
)
38
40
from pandas .core .dtypes .missing import isna , notna
39
41
40
- import pandas as pd
41
42
from pandas ._typing import ArrayLike
42
43
from pandas .core .construction import array , extract_array
43
44
from pandas .core .ops .array_ops import comp_method_OBJECT_ARRAY , define_na_arithmetic_op
@@ -148,6 +149,8 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
148
149
Be careful to call this *after* determining the `name` attribute to be
149
150
attached to the result of the arithmetic operation.
150
151
"""
152
+ from pandas .core .arrays import TimedeltaArray
153
+
151
154
if type (obj ) is datetime .timedelta :
152
155
# GH#22390 cast up to Timedelta to rely on Timedelta
153
156
# implementation; otherwise operation against numeric-dtype
@@ -157,12 +160,10 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
157
160
if isna (obj ):
158
161
# wrapping timedelta64("NaT") in Timedelta returns NaT,
159
162
# which would incorrectly be treated as a datetime-NaT, so
160
- # we broadcast and wrap in a Series
163
+ # we broadcast and wrap in a TimedeltaArray
164
+ obj = obj .astype ("timedelta64[ns]" )
161
165
right = np .broadcast_to (obj , shape )
162
-
163
- # Note: we use Series instead of TimedeltaIndex to avoid having
164
- # to worry about catching NullFrequencyError.
165
- return pd .Series (right )
166
+ return TimedeltaArray (right )
166
167
167
168
# In particular non-nanosecond timedelta64 needs to be cast to
168
169
# nanoseconds, or else we get undesired behavior like
@@ -173,7 +174,7 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
173
174
# GH#22390 Unfortunately we need to special-case right-hand
174
175
# timedelta64 dtypes because numpy casts integer dtypes to
175
176
# timedelta64 when operating with timedelta64
176
- return pd . TimedeltaIndex (obj )
177
+ return TimedeltaArray . _from_sequence (obj )
177
178
return obj
178
179
179
180
@@ -520,13 +521,34 @@ def column_op(a, b):
520
521
return result
521
522
522
523
523
- def dispatch_to_extension_op (op , left , right ):
524
+ def dispatch_to_extension_op (
525
+ op ,
526
+ left : Union [ABCExtensionArray , np .ndarray ],
527
+ right : Any ,
528
+ keep_null_freq : bool = False ,
529
+ ):
524
530
"""
525
531
Assume that left or right is a Series backed by an ExtensionArray,
526
532
apply the operator defined by op.
533
+
534
+ Parameters
535
+ ----------
536
+ op : binary operator
537
+ left : ExtensionArray or np.ndarray
538
+ right : object
539
+ keep_null_freq : bool, default False
540
+ Whether to re-raise a NullFrequencyError unchanged, as opposed to
541
+ catching and raising TypeError.
542
+
543
+ Returns
544
+ -------
545
+ ExtensionArray or np.ndarray
546
+ 2-tuple of these if op is divmod or rdivmod
527
547
"""
548
+ # NB: left and right should already be unboxed, so neither should be
549
+ # a Series or Index.
528
550
529
- if left .dtype .kind in "mM" :
551
+ if left .dtype .kind in "mM" and isinstance ( left , np . ndarray ) :
530
552
# We need to cast datetime64 and timedelta64 ndarrays to
531
553
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
532
554
# PandasArray as that behaves poorly with e.g. IntegerArray.
@@ -535,15 +557,15 @@ def dispatch_to_extension_op(op, left, right):
535
557
# The op calls will raise TypeError if the op is not defined
536
558
# on the ExtensionArray
537
559
538
- # unbox Series and Index to arrays
539
- new_left = extract_array (left , extract_numpy = True )
540
- new_right = extract_array (right , extract_numpy = True )
541
-
542
560
try :
543
- res_values = op (new_left , new_right )
561
+ res_values = op (left , right )
544
562
except NullFrequencyError :
545
563
# DatetimeIndex and TimedeltaIndex with freq == None raise ValueError
546
564
# on add/sub of integers (or int-like). We re-raise as a TypeError.
565
+ if keep_null_freq :
566
+ # TODO: remove keep_null_freq after Timestamp+int deprecation
567
+ # GH#22535 is enforced
568
+ raise
547
569
raise TypeError (
548
570
"incompatible type for a datetime/timedelta "
549
571
"operation [{name}]" .format (name = op .__name__ )
@@ -615,25 +637,29 @@ def wrapper(left, right):
615
637
if isinstance (right , ABCDataFrame ):
616
638
return NotImplemented
617
639
640
+ keep_null_freq = isinstance (
641
+ right ,
642
+ (ABCDatetimeIndex , ABCDatetimeArray , ABCTimedeltaIndex , ABCTimedeltaArray ),
643
+ )
644
+
618
645
left , right = _align_method_SERIES (left , right )
619
646
res_name = get_op_result_name (left , right )
620
- right = maybe_upcast_for_op (right , left .shape )
621
647
622
- if should_extension_dispatch (left , right ):
623
- result = dispatch_to_extension_op ( op , left , right )
648
+ lvalues = extract_array (left , extract_numpy = True )
649
+ rvalues = extract_array ( right , extract_numpy = True )
624
650
625
- elif is_timedelta64_dtype (right ) or isinstance (
626
- right , (ABCDatetimeArray , ABCDatetimeIndex )
627
- ):
628
- # We should only get here with td64 right with non-scalar values
629
- # for right upcast by maybe_upcast_for_op
630
- assert not isinstance (right , (np .timedelta64 , np .ndarray ))
631
- result = op (left ._values , right )
651
+ rvalues = maybe_upcast_for_op (rvalues , lvalues .shape )
632
652
633
- else :
634
- lvalues = extract_array (left , extract_numpy = True )
635
- rvalues = extract_array (right , extract_numpy = True )
653
+ if should_extension_dispatch (lvalues , rvalues ):
654
+ result = dispatch_to_extension_op (op , lvalues , rvalues , keep_null_freq )
655
+
656
+ elif is_timedelta64_dtype (rvalues ) or isinstance (rvalues , ABCDatetimeArray ):
657
+ # We should only get here with td64 rvalues with non-scalar values
658
+ # for rvalues upcast by maybe_upcast_for_op
659
+ assert not isinstance (rvalues , (np .timedelta64 , np .ndarray ))
660
+ result = dispatch_to_extension_op (op , lvalues , rvalues , keep_null_freq )
636
661
662
+ else :
637
663
with np .errstate (all = "ignore" ):
638
664
result = na_op (lvalues , rvalues )
639
665
@@ -708,25 +734,25 @@ def wrapper(self, other, axis=None):
708
734
if len (self ) != len (other ):
709
735
raise ValueError ("Lengths must match to compare" )
710
736
711
- if should_extension_dispatch (self , other ):
712
- res_values = dispatch_to_extension_op ( op , self , other )
737
+ lvalues = extract_array (self , extract_numpy = True )
738
+ rvalues = extract_array ( other , extract_numpy = True )
713
739
714
- elif is_scalar (other ) and isna (other ):
740
+ if should_extension_dispatch (lvalues , rvalues ):
741
+ res_values = dispatch_to_extension_op (op , lvalues , rvalues )
742
+
743
+ elif is_scalar (rvalues ) and isna (rvalues ):
715
744
# numpy does not like comparisons vs None
716
745
if op is operator .ne :
717
- res_values = np .ones (len (self ), dtype = bool )
746
+ res_values = np .ones (len (lvalues ), dtype = bool )
718
747
else :
719
- res_values = np .zeros (len (self ), dtype = bool )
748
+ res_values = np .zeros (len (lvalues ), dtype = bool )
720
749
721
750
else :
722
- lvalues = extract_array (self , extract_numpy = True )
723
- rvalues = extract_array (other , extract_numpy = True )
724
-
725
751
with np .errstate (all = "ignore" ):
726
752
res_values = na_op (lvalues , rvalues )
727
753
if is_scalar (res_values ):
728
754
raise TypeError (
729
- "Could not compare {typ} type with Series" .format (typ = type (other ))
755
+ "Could not compare {typ} type with Series" .format (typ = type (rvalues ))
730
756
)
731
757
732
758
result = self ._constructor (res_values , index = self .index )
0 commit comments