7
7
8
8
from pandas ._config import get_option
9
9
10
- from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
10
+ from pandas ._libs import NaT , Period , Timedelta , Timestamp , iNaT , lib
11
11
from pandas ._typing import Dtype , Scalar
12
12
from pandas .compat ._optional import import_optional_dependency
13
13
17
17
is_any_int_dtype ,
18
18
is_bool_dtype ,
19
19
is_complex ,
20
- is_datetime64_dtype ,
21
- is_datetime64tz_dtype ,
22
- is_datetime_or_timedelta_dtype ,
20
+ is_datetime64_any_dtype ,
23
21
is_float ,
24
22
is_float_dtype ,
25
23
is_integer ,
28
26
is_object_dtype ,
29
27
is_scalar ,
30
28
is_timedelta64_dtype ,
29
+ needs_i8_conversion ,
31
30
pandas_dtype ,
32
31
)
32
+ from pandas .core .dtypes .dtypes import PeriodDtype
33
33
from pandas .core .dtypes .missing import isna , na_value_for_dtype , notna
34
34
35
35
from pandas .core .construction import extract_array
@@ -134,10 +134,8 @@ def f(
134
134
135
135
136
136
def _bn_ok_dtype (dtype : Dtype , name : str ) -> bool :
137
- # Bottleneck chokes on datetime64
138
- if not is_object_dtype (dtype ) and not (
139
- is_datetime_or_timedelta_dtype (dtype ) or is_datetime64tz_dtype (dtype )
140
- ):
137
+ # Bottleneck chokes on datetime64, PeriodDtype (or and EA)
138
+ if not is_object_dtype (dtype ) and not needs_i8_conversion (dtype ):
141
139
142
140
# GH 15507
143
141
# bottleneck does not properly upcast during the sum
@@ -283,17 +281,16 @@ def _get_values(
283
281
# with scalar fill_value. This guarantee is important for the
284
282
# maybe_upcast_putmask call below
285
283
assert is_scalar (fill_value )
284
+ values = extract_array (values , extract_numpy = True )
286
285
287
286
mask = _maybe_get_mask (values , skipna , mask )
288
287
289
- values = extract_array (values , extract_numpy = True )
290
288
dtype = values .dtype
291
289
292
- if is_datetime_or_timedelta_dtype ( values ) or is_datetime64tz_dtype (values ):
290
+ if needs_i8_conversion (values ):
293
291
# changing timedelta64/datetime64 to int64 needs to happen after
294
292
# finding `mask` above
295
- values = getattr (values , "asi8" , values )
296
- values = values .view (np .int64 )
293
+ values = np .asarray (values .view ("i8" ))
297
294
298
295
dtype_ok = _na_ok_dtype (dtype )
299
296
@@ -307,7 +304,8 @@ def _get_values(
307
304
308
305
if skipna and copy :
309
306
values = values .copy ()
310
- if dtype_ok :
307
+ assert mask is not None # for mypy
308
+ if dtype_ok and mask .any ():
311
309
np .putmask (values , mask , fill_value )
312
310
313
311
# promote if needed
@@ -325,13 +323,14 @@ def _get_values(
325
323
326
324
327
325
def _na_ok_dtype (dtype ) -> bool :
328
- # TODO: what about datetime64tz? PeriodDtype?
329
- return not issubclass (dtype .type , (np .integer , np .timedelta64 , np .datetime64 ))
326
+ if needs_i8_conversion (dtype ):
327
+ return False
328
+ return not issubclass (dtype .type , np .integer )
330
329
331
330
332
331
def _wrap_results (result , dtype : Dtype , fill_value = None ):
333
332
""" wrap our results if needed """
334
- if is_datetime64_dtype ( dtype ) or is_datetime64tz_dtype (dtype ):
333
+ if is_datetime64_any_dtype (dtype ):
335
334
if fill_value is None :
336
335
# GH#24293
337
336
fill_value = iNaT
@@ -342,7 +341,8 @@ def _wrap_results(result, dtype: Dtype, fill_value=None):
342
341
result = np .nan
343
342
result = Timestamp (result , tz = tz )
344
343
else :
345
- result = result .view (dtype )
344
+ # If we have float dtype, taking a view will give the wrong result
345
+ result = result .astype (dtype )
346
346
elif is_timedelta64_dtype (dtype ):
347
347
if not isinstance (result , np .ndarray ):
348
348
if result == fill_value :
@@ -356,6 +356,14 @@ def _wrap_results(result, dtype: Dtype, fill_value=None):
356
356
else :
357
357
result = result .astype ("m8[ns]" ).view (dtype )
358
358
359
+ elif isinstance (dtype , PeriodDtype ):
360
+ if is_float (result ) and result .is_integer ():
361
+ result = int (result )
362
+ if is_integer (result ):
363
+ result = Period ._from_ordinal (result , freq = dtype .freq )
364
+ else :
365
+ raise NotImplementedError (type (result ), result )
366
+
359
367
return result
360
368
361
369
@@ -542,12 +550,7 @@ def nanmean(values, axis=None, skipna=True, mask=None):
542
550
)
543
551
dtype_sum = dtype_max
544
552
dtype_count = np .float64
545
- if (
546
- is_integer_dtype (dtype )
547
- or is_timedelta64_dtype (dtype )
548
- or is_datetime64_dtype (dtype )
549
- or is_datetime64tz_dtype (dtype )
550
- ):
553
+ if is_integer_dtype (dtype ) or needs_i8_conversion (dtype ):
551
554
dtype_sum = np .float64
552
555
elif is_float_dtype (dtype ):
553
556
dtype_sum = dtype
0 commit comments