@@ -367,6 +367,32 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None):
367
367
return result
368
368
369
369
370
+ def _datetimelike_compat (func ):
371
+ """
372
+ If we have datetime64 or timedelta64 values, ensure we have a correct
373
+ mask before calling the wrapped function, then cast back afterwards.
374
+ """
375
+
376
+ @functools .wraps (func )
377
+ def new_func (values , * , axis = None , skipna = True , mask = None , ** kwargs ):
378
+ orig_values = values
379
+
380
+ datetimelike = values .dtype .kind in ["m" , "M" ]
381
+ if datetimelike and mask is None :
382
+ mask = isna (values )
383
+
384
+ result = func (values , axis = axis , skipna = skipna , mask = mask , ** kwargs )
385
+
386
+ if datetimelike :
387
+ result = _wrap_results (result , orig_values .dtype , fill_value = iNaT )
388
+ if not skipna :
389
+ result = _mask_datetimelike_result (result , axis , mask , orig_values )
390
+
391
+ return result
392
+
393
+ return new_func
394
+
395
+
370
396
def _na_for_min_count (
371
397
values : np .ndarray , axis : Optional [int ]
372
398
) -> Union [Scalar , np .ndarray ]:
@@ -480,6 +506,7 @@ def nanall(
480
506
481
507
482
508
@disallow ("M8" )
509
+ @_datetimelike_compat
483
510
def nansum (
484
511
values : np .ndarray ,
485
512
* ,
@@ -511,25 +538,18 @@ def nansum(
511
538
>>> nanops.nansum(s)
512
539
3.0
513
540
"""
514
- orig_values = values
515
-
516
541
values , mask , dtype , dtype_max , _ = _get_values (
517
542
values , skipna , fill_value = 0 , mask = mask
518
543
)
519
544
dtype_sum = dtype_max
520
- datetimelike = False
521
545
if is_float_dtype (dtype ):
522
546
dtype_sum = dtype
523
547
elif is_timedelta64_dtype (dtype ):
524
- datetimelike = True
525
548
dtype_sum = np .float64
526
549
527
550
the_sum = values .sum (axis , dtype = dtype_sum )
528
551
the_sum = _maybe_null_out (the_sum , axis , mask , values .shape , min_count = min_count )
529
552
530
- the_sum = _wrap_results (the_sum , dtype )
531
- if datetimelike and not skipna :
532
- the_sum = _mask_datetimelike_result (the_sum , axis , mask , orig_values )
533
553
return the_sum
534
554
535
555
@@ -552,6 +572,7 @@ def _mask_datetimelike_result(
552
572
553
573
@disallow (PeriodDtype )
554
574
@bottleneck_switch ()
575
+ @_datetimelike_compat
555
576
def nanmean (
556
577
values : np .ndarray ,
557
578
* ,
@@ -583,18 +604,14 @@ def nanmean(
583
604
>>> nanops.nanmean(s)
584
605
1.5
585
606
"""
586
- orig_values = values
587
-
588
607
values , mask , dtype , dtype_max , _ = _get_values (
589
608
values , skipna , fill_value = 0 , mask = mask
590
609
)
591
610
dtype_sum = dtype_max
592
611
dtype_count = np .float64
593
612
594
613
# not using needs_i8_conversion because that includes period
595
- datetimelike = False
596
614
if dtype .kind in ["m" , "M" ]:
597
- datetimelike = True
598
615
dtype_sum = np .float64
599
616
elif is_integer_dtype (dtype ):
600
617
dtype_sum = np .float64
@@ -616,9 +633,6 @@ def nanmean(
616
633
else :
617
634
the_mean = the_sum / count if count > 0 else np .nan
618
635
619
- the_mean = _wrap_results (the_mean , dtype )
620
- if datetimelike and not skipna :
621
- the_mean = _mask_datetimelike_result (the_mean , axis , mask , orig_values )
622
636
return the_mean
623
637
624
638
@@ -875,7 +889,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None):
875
889
# precision as the original values array.
876
890
if is_float_dtype (dtype ):
877
891
result = result .astype (dtype )
878
- return _wrap_results ( result , values . dtype )
892
+ return result
879
893
880
894
881
895
@disallow ("M8" , "m8" )
@@ -930,6 +944,7 @@ def nansem(
930
944
931
945
def _nanminmax (meth , fill_value_typ ):
932
946
@bottleneck_switch (name = "nan" + meth )
947
+ @_datetimelike_compat
933
948
def reduction (
934
949
values : np .ndarray ,
935
950
* ,
@@ -938,13 +953,10 @@ def reduction(
938
953
mask : Optional [np .ndarray ] = None ,
939
954
) -> Dtype :
940
955
941
- orig_values = values
942
956
values , mask , dtype , dtype_max , fill_value = _get_values (
943
957
values , skipna , fill_value_typ = fill_value_typ , mask = mask
944
958
)
945
959
946
- datetimelike = orig_values .dtype .kind in ["m" , "M" ]
947
-
948
960
if (axis is not None and values .shape [axis ] == 0 ) or values .size == 0 :
949
961
try :
950
962
result = getattr (values , meth )(axis , dtype = dtype_max )
@@ -954,12 +966,7 @@ def reduction(
954
966
else :
955
967
result = getattr (values , meth )(axis )
956
968
957
- result = _wrap_results (result , dtype , fill_value )
958
969
result = _maybe_null_out (result , axis , mask , values .shape )
959
-
960
- if datetimelike and not skipna :
961
- result = _mask_datetimelike_result (result , axis , mask , orig_values )
962
-
963
970
return result
964
971
965
972
return reduction
0 commit comments