4
4
5
5
from __future__ import annotations
6
6
7
+ from contextlib import suppress
7
8
from datetime import (
8
9
date ,
9
10
datetime ,
28
29
NaT ,
29
30
OutOfBoundsDatetime ,
30
31
OutOfBoundsTimedelta ,
32
+ Period ,
31
33
Timedelta ,
32
34
Timestamp ,
33
35
conversion ,
55
57
ensure_str ,
56
58
is_bool ,
57
59
is_bool_dtype ,
60
+ is_categorical_dtype ,
58
61
is_complex ,
59
62
is_complex_dtype ,
60
63
is_datetime64_dtype ,
78
81
pandas_dtype ,
79
82
)
80
83
from pandas .core .dtypes .dtypes import (
81
- CategoricalDtype ,
82
84
DatetimeTZDtype ,
83
85
ExtensionDtype ,
84
86
IntervalDtype ,
85
87
PeriodDtype ,
86
88
)
87
89
from pandas .core .dtypes .generic import (
90
+ ABCDataFrame ,
88
91
ABCExtensionArray ,
89
92
ABCSeries ,
90
93
)
@@ -189,13 +192,13 @@ def maybe_box_native(value: Scalar) -> Scalar:
189
192
value = maybe_box_datetimelike (value )
190
193
elif is_float (value ):
191
194
# error: Argument 1 to "float" has incompatible type
192
- # "Union[Union[str, int, float, bool], Union[Any, Timestamp , Timedelta, Any]]";
195
+ # "Union[Union[str, int, float, bool], Union[Any, Any , Timedelta, Any]]";
193
196
# expected "Union[SupportsFloat, _SupportsIndex, str]"
194
197
value = float (value ) # type: ignore[arg-type]
195
198
elif is_integer (value ):
196
199
# error: Argument 1 to "int" has incompatible type
197
- # "Union[Union[str, int, float, bool], Union[Any, Timestamp , Timedelta, Any]]";
198
- # expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
200
+ # "Union[Union[str, int, float, bool], Union[Any, Any , Timedelta, Any]]";
201
+ # expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
199
202
value = int (value ) # type: ignore[arg-type]
200
203
elif is_bool (value ):
201
204
value = bool (value )
@@ -246,6 +249,9 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
246
249
try to cast to the specified dtype (e.g. convert back to bool/int
247
250
or could be an astype of float64->float32
248
251
"""
252
+ if isinstance (result , ABCDataFrame ):
253
+ # see test_pivot_table_doctest_case
254
+ return result
249
255
do_round = False
250
256
251
257
if isinstance (dtype , str ):
@@ -272,9 +278,15 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
272
278
273
279
dtype = np .dtype (dtype )
274
280
275
- if not isinstance (dtype , np .dtype ):
276
- # enforce our signature annotation
277
- raise TypeError (dtype ) # pragma: no cover
281
+ elif dtype .type is Period :
282
+ from pandas .core .arrays import PeriodArray
283
+
284
+ with suppress (TypeError ):
285
+ # e.g. TypeError: int() argument must be a string, a
286
+ # bytes-like object or a number, not 'Period'
287
+
288
+ # error: "dtype[Any]" has no attribute "freq"
289
+ return PeriodArray (result , freq = dtype .freq ) # type: ignore[attr-defined]
278
290
279
291
converted = maybe_downcast_numeric (result , dtype , do_round )
280
292
if converted is not result :
@@ -283,7 +295,15 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
283
295
# a datetimelike
284
296
# GH12821, iNaT is cast to float
285
297
if dtype .kind in ["M" , "m" ] and result .dtype .kind in ["i" , "f" ]:
286
- result = result .astype (dtype )
298
+ if isinstance (dtype , DatetimeTZDtype ):
299
+ # convert to datetime and change timezone
300
+ i8values = result .astype ("i8" , copy = False )
301
+ cls = dtype .construct_array_type ()
302
+ # equiv: DatetimeArray(i8values).tz_localize("UTC").tz_convert(dtype.tz)
303
+ dt64values = i8values .view ("M8[ns]" )
304
+ result = cls ._simple_new (dt64values , dtype = dtype )
305
+ else :
306
+ result = result .astype (dtype )
287
307
288
308
return result
289
309
@@ -359,15 +379,15 @@ def trans(x):
359
379
return result
360
380
361
381
362
- def maybe_cast_pointwise_result (
382
+ def maybe_cast_result (
363
383
result : ArrayLike ,
364
384
dtype : DtypeObj ,
365
385
numeric_only : bool = False ,
386
+ how : str = "" ,
366
387
same_dtype : bool = True ,
367
388
) -> ArrayLike :
368
389
"""
369
- Try casting result of a pointwise operation back to the original dtype if
370
- appropriate.
390
+ Try casting result to a different type if appropriate.
371
391
372
392
Parameters
373
393
----------
@@ -377,6 +397,8 @@ def maybe_cast_pointwise_result(
377
397
Input Series from which result was calculated.
378
398
numeric_only : bool, default False
379
399
Whether to cast only numerics or datetimes as well.
400
+ how : str, default ""
401
+ How the result was computed.
380
402
same_dtype : bool, default True
381
403
Specify dtype when calling _from_sequence
382
404
@@ -385,12 +407,12 @@ def maybe_cast_pointwise_result(
385
407
result : array-like
386
408
result maybe casted to the dtype.
387
409
"""
410
+ dtype = maybe_cast_result_dtype (dtype , how )
388
411
389
412
assert not is_scalar (result )
390
413
391
414
if isinstance (dtype , ExtensionDtype ):
392
- if not isinstance (dtype , (CategoricalDtype , DatetimeTZDtype )):
393
- # TODO: avoid this special-casing
415
+ if not is_categorical_dtype (dtype ) and dtype .kind != "M" :
394
416
# We have to special case categorical so as not to upcast
395
417
# things like counts back to categorical
396
418
@@ -406,6 +428,42 @@ def maybe_cast_pointwise_result(
406
428
return result
407
429
408
430
431
+ def maybe_cast_result_dtype (dtype : DtypeObj , how : str ) -> DtypeObj :
432
+ """
433
+ Get the desired dtype of a result based on the
434
+ input dtype and how it was computed.
435
+
436
+ Parameters
437
+ ----------
438
+ dtype : DtypeObj
439
+ Input dtype.
440
+ how : str
441
+ How the result was computed.
442
+
443
+ Returns
444
+ -------
445
+ DtypeObj
446
+ The desired dtype of the result.
447
+ """
448
+ from pandas .core .arrays .boolean import BooleanDtype
449
+ from pandas .core .arrays .floating import Float64Dtype
450
+ from pandas .core .arrays .integer import (
451
+ Int64Dtype ,
452
+ _IntegerDtype ,
453
+ )
454
+
455
+ if how in ["add" , "cumsum" , "sum" , "prod" ]:
456
+ if dtype == np .dtype (bool ):
457
+ return np .dtype (np .int64 )
458
+ elif isinstance (dtype , (BooleanDtype , _IntegerDtype )):
459
+ return Int64Dtype ()
460
+ elif how in ["mean" , "median" , "var" ] and isinstance (
461
+ dtype , (BooleanDtype , _IntegerDtype )
462
+ ):
463
+ return Float64Dtype ()
464
+ return dtype
465
+
466
+
409
467
def maybe_cast_to_extension_array (
410
468
cls : type [ExtensionArray ], obj : ArrayLike , dtype : ExtensionDtype | None = None
411
469
) -> ArrayLike :
@@ -729,9 +787,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj,
729
787
except OutOfBoundsDatetime :
730
788
return np .dtype (object ), val
731
789
732
- # error: Non-overlapping identity check (left operand type: "Timestamp",
733
- # right operand type: "NaTType")
734
- if val is NaT or val .tz is None : # type: ignore[comparison-overlap]
790
+ if val is NaT or val .tz is None :
735
791
dtype = np .dtype ("M8[ns]" )
736
792
val = val .to_datetime64 ()
737
793
else :
@@ -2058,7 +2114,7 @@ def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
2058
2114
ValueError
2059
2115
"""
2060
2116
# error: Argument 1 to "__call__" of "ufunc" has incompatible type
2061
- # "Union[Union[str, int, float, bool], Union[Any, Timestamp , Timedelta, Any]]";
2117
+ # "Union[Union[str, int, float, bool], Union[Any, Any , Timedelta, Any]]";
2062
2118
# expected "Union[Union[int, float, complex, str, bytes, generic],
2063
2119
# Sequence[Union[int, float, complex, str, bytes, generic]],
2064
2120
# Sequence[Sequence[Any]], _SupportsArray]"
0 commit comments