7
7
from pandas ._libs .tslibs import iNaT
8
8
import pandas .compat as compat
9
9
10
+ from pandas .core .dtypes .common import is_datetime64_any_dtype
11
+
10
12
from pandas import (
11
13
DatetimeIndex ,
12
14
Index ,
18
20
Timestamp ,
19
21
isna ,
20
22
)
21
- from pandas .core .arrays import PeriodArray
23
+ from pandas .core .arrays import DatetimeArray , PeriodArray , TimedeltaArray
22
24
from pandas .util import testing as tm
23
25
24
26
@@ -397,7 +399,9 @@ def test_nat_rfloordiv_timedelta(val, expected):
397
399
"value" ,
398
400
[
399
401
DatetimeIndex (["2011-01-01" , "2011-01-02" ], name = "x" ),
400
- DatetimeIndex (["2011-01-01" , "2011-01-02" ], name = "x" ),
402
+ DatetimeIndex (["2011-01-01" , "2011-01-02" ], tz = "US/Eastern" , name = "x" ),
403
+ DatetimeArray ._from_sequence (["2011-01-01" , "2011-01-02" ]),
404
+ DatetimeArray ._from_sequence (["2011-01-01" , "2011-01-02" ], tz = "US/Pacific" ),
401
405
TimedeltaIndex (["1 day" , "2 day" ], name = "x" ),
402
406
],
403
407
)
@@ -406,19 +410,24 @@ def test_nat_arithmetic_index(op_name, value):
406
410
exp_name = "x"
407
411
exp_data = [NaT ] * 2
408
412
409
- if isinstance (value , DatetimeIndex ) and "plus" in op_name :
410
- expected = DatetimeIndex (exp_data , name = exp_name , tz = value .tz )
413
+ if is_datetime64_any_dtype (value . dtype ) and "plus" in op_name :
414
+ expected = DatetimeIndex (exp_data , tz = value .tz , name = exp_name )
411
415
else :
412
416
expected = TimedeltaIndex (exp_data , name = exp_name )
413
417
414
- tm .assert_index_equal (_ops [op_name ](NaT , value ), expected )
418
+ if not isinstance (value , Index ):
419
+ expected = expected .array
420
+
421
+ op = _ops [op_name ]
422
+ result = op (NaT , value )
423
+ tm .assert_equal (result , expected )
415
424
416
425
417
426
@pytest .mark .parametrize (
418
427
"op_name" ,
419
428
["left_plus_right" , "right_plus_left" , "left_minus_right" , "right_minus_left" ],
420
429
)
421
- @pytest .mark .parametrize ("box" , [TimedeltaIndex , Series ])
430
+ @pytest .mark .parametrize ("box" , [TimedeltaIndex , Series , TimedeltaArray . _from_sequence ])
422
431
def test_nat_arithmetic_td64_vector (op_name , box ):
423
432
# see gh-19124
424
433
vec = box (["1 day" , "2 day" ], dtype = "timedelta64[ns]" )
0 commit comments