@@ -454,8 +454,9 @@ def test_tz_dtype_matches(self):
454
454
455
455
456
456
class TestReductions :
457
- @pytest .mark .parametrize ("tz" , [None , "US/Central" ])
458
- def test_min_max (self , tz ):
457
+ @pytest .fixture
458
+ def arr1d (self , tz_naive_fixture ):
459
+ tz = tz_naive_fixture
459
460
dtype = DatetimeTZDtype (tz = tz ) if tz is not None else np .dtype ("M8[ns]" )
460
461
arr = DatetimeArray ._from_sequence (
461
462
[
@@ -468,6 +469,11 @@ def test_min_max(self, tz):
468
469
],
469
470
dtype = dtype ,
470
471
)
472
+ return arr
473
+
474
+ def test_min_max (self , arr1d ):
475
+ arr = arr1d
476
+ tz = arr .tz
471
477
472
478
result = arr .min ()
473
479
expected = pd .Timestamp ("2000-01-02" , tz = tz )
@@ -493,3 +499,70 @@ def test_min_max_empty(self, skipna, tz):
493
499
494
500
result = arr .max (skipna = skipna )
495
501
assert result is pd .NaT
502
+
503
+ @pytest .mark .parametrize ("tz" , [None , "US/Central" ])
504
+ @pytest .mark .parametrize ("skipna" , [True , False ])
505
+ def test_median_empty (self , skipna , tz ):
506
+ dtype = DatetimeTZDtype (tz = tz ) if tz is not None else np .dtype ("M8[ns]" )
507
+ arr = DatetimeArray ._from_sequence ([], dtype = dtype )
508
+ result = arr .median (skipna = skipna )
509
+ assert result is pd .NaT
510
+
511
+ arr = arr .reshape (0 , 3 )
512
+ result = arr .median (axis = 0 , skipna = skipna )
513
+ expected = type (arr )._from_sequence ([pd .NaT , pd .NaT , pd .NaT ], dtype = arr .dtype )
514
+ tm .assert_equal (result , expected )
515
+
516
+ result = arr .median (axis = 1 , skipna = skipna )
517
+ expected = type (arr )._from_sequence ([pd .NaT ], dtype = arr .dtype )
518
+ tm .assert_equal (result , expected )
519
+
520
+ def test_median (self , arr1d ):
521
+ arr = arr1d
522
+
523
+ result = arr .median ()
524
+ assert result == arr [0 ]
525
+ result = arr .median (skipna = False )
526
+ assert result is pd .NaT
527
+
528
+ result = arr .dropna ().median (skipna = False )
529
+ assert result == arr [0 ]
530
+
531
+ result = arr .median (axis = 0 )
532
+ assert result == arr [0 ]
533
+
534
+ def test_median_axis (self , arr1d ):
535
+ arr = arr1d
536
+ assert arr .median (axis = 0 ) == arr .median ()
537
+ assert arr .median (axis = 0 , skipna = False ) is pd .NaT
538
+
539
+ msg = r"abs\(axis\) must be less than ndim"
540
+ with pytest .raises (ValueError , match = msg ):
541
+ arr .median (axis = 1 )
542
+
543
+ @pytest .mark .filterwarnings ("ignore:All-NaN slice encountered:RuntimeWarning" )
544
+ def test_median_2d (self , arr1d ):
545
+ arr = arr1d .reshape (1 , - 1 )
546
+
547
+ # axis = None
548
+ assert arr .median () == arr1d .median ()
549
+ assert arr .median (skipna = False ) is pd .NaT
550
+
551
+ # axis = 0
552
+ result = arr .median (axis = 0 )
553
+ expected = arr1d
554
+ tm .assert_equal (result , expected )
555
+
556
+ # Since column 3 is all-NaT, we get NaT there with or without skipna
557
+ result = arr .median (axis = 0 , skipna = False )
558
+ expected = arr1d
559
+ tm .assert_equal (result , expected )
560
+
561
+ # axis = 1
562
+ result = arr .median (axis = 1 )
563
+ expected = type (arr )._from_sequence ([arr1d .median ()])
564
+ tm .assert_equal (result , expected )
565
+
566
+ result = arr .median (axis = 1 , skipna = False )
567
+ expected = type (arr )._from_sequence ([pd .NaT ], dtype = arr .dtype )
568
+ tm .assert_equal (result , expected )
0 commit comments