Skip to content

Commit 482141c

Browse files
jbrockmendel authored and yeshsurya committed
TYP: timedeltas.pyi (pandas-dev#40766)
1 parent 38af11c commit 482141c

File tree

1 file changed

+73
-17
lines changed

1 file changed

+73
-17
lines changed

pandas/core/dtypes/cast.py

+73-17
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
from contextlib import suppress
78
from datetime import (
89
date,
910
datetime,
@@ -28,6 +29,7 @@
2829
NaT,
2930
OutOfBoundsDatetime,
3031
OutOfBoundsTimedelta,
32+
Period,
3133
Timedelta,
3234
Timestamp,
3335
conversion,
@@ -55,6 +57,7 @@
5557
ensure_str,
5658
is_bool,
5759
is_bool_dtype,
60+
is_categorical_dtype,
5861
is_complex,
5962
is_complex_dtype,
6063
is_datetime64_dtype,
@@ -78,13 +81,13 @@
7881
pandas_dtype,
7982
)
8083
from pandas.core.dtypes.dtypes import (
81-
CategoricalDtype,
8284
DatetimeTZDtype,
8385
ExtensionDtype,
8486
IntervalDtype,
8587
PeriodDtype,
8688
)
8789
from pandas.core.dtypes.generic import (
90+
ABCDataFrame,
8891
ABCExtensionArray,
8992
ABCSeries,
9093
)
@@ -189,13 +192,13 @@ def maybe_box_native(value: Scalar) -> Scalar:
189192
value = maybe_box_datetimelike(value)
190193
elif is_float(value):
191194
# error: Argument 1 to "float" has incompatible type
192-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
195+
# "Union[Union[str, int, float, bool], Union[Any, Any, Timedelta, Any]]";
193196
# expected "Union[SupportsFloat, _SupportsIndex, str]"
194197
value = float(value) # type: ignore[arg-type]
195198
elif is_integer(value):
196199
# error: Argument 1 to "int" has incompatible type
197-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
198-
# expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
200+
# "Union[Union[str, int, float, bool], Union[Any, Any, Timedelta, Any]]";
201+
# expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
199202
value = int(value) # type: ignore[arg-type]
200203
elif is_bool(value):
201204
value = bool(value)
@@ -246,6 +249,9 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
246249
try to cast to the specified dtype (e.g. convert back to bool/int
247250
or could be an astype of float64->float32
248251
"""
252+
if isinstance(result, ABCDataFrame):
253+
# see test_pivot_table_doctest_case
254+
return result
249255
do_round = False
250256

251257
if isinstance(dtype, str):
@@ -272,9 +278,15 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
272278

273279
dtype = np.dtype(dtype)
274280

275-
if not isinstance(dtype, np.dtype):
276-
# enforce our signature annotation
277-
raise TypeError(dtype) # pragma: no cover
281+
elif dtype.type is Period:
282+
from pandas.core.arrays import PeriodArray
283+
284+
with suppress(TypeError):
285+
# e.g. TypeError: int() argument must be a string, a
286+
# bytes-like object or a number, not 'Period'
287+
288+
# error: "dtype[Any]" has no attribute "freq"
289+
return PeriodArray(result, freq=dtype.freq) # type: ignore[attr-defined]
278290

279291
converted = maybe_downcast_numeric(result, dtype, do_round)
280292
if converted is not result:
@@ -283,7 +295,15 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
283295
# a datetimelike
284296
# GH12821, iNaT is cast to float
285297
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
286-
result = result.astype(dtype)
298+
if isinstance(dtype, DatetimeTZDtype):
299+
# convert to datetime and change timezone
300+
i8values = result.astype("i8", copy=False)
301+
cls = dtype.construct_array_type()
302+
# equiv: DatetimeArray(i8values).tz_localize("UTC").tz_convert(dtype.tz)
303+
dt64values = i8values.view("M8[ns]")
304+
result = cls._simple_new(dt64values, dtype=dtype)
305+
else:
306+
result = result.astype(dtype)
287307

288308
return result
289309

@@ -359,15 +379,15 @@ def trans(x):
359379
return result
360380

361381

362-
def maybe_cast_pointwise_result(
382+
def maybe_cast_result(
363383
result: ArrayLike,
364384
dtype: DtypeObj,
365385
numeric_only: bool = False,
386+
how: str = "",
366387
same_dtype: bool = True,
367388
) -> ArrayLike:
368389
"""
369-
Try casting result of a pointwise operation back to the original dtype if
370-
appropriate.
390+
Try casting result to a different type if appropriate.
371391
372392
Parameters
373393
----------
@@ -377,6 +397,8 @@ def maybe_cast_pointwise_result(
377397
Input Series from which result was calculated.
378398
numeric_only : bool, default False
379399
Whether to cast only numerics or datetimes as well.
400+
how : str, default ""
401+
How the result was computed.
380402
same_dtype : bool, default True
381403
Specify dtype when calling _from_sequence
382404
@@ -385,12 +407,12 @@ def maybe_cast_pointwise_result(
385407
result : array-like
386408
result maybe casted to the dtype.
387409
"""
410+
dtype = maybe_cast_result_dtype(dtype, how)
388411

389412
assert not is_scalar(result)
390413

391414
if isinstance(dtype, ExtensionDtype):
392-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
393-
# TODO: avoid this special-casing
415+
if not is_categorical_dtype(dtype) and dtype.kind != "M":
394416
# We have to special case categorical so as not to upcast
395417
# things like counts back to categorical
396418

@@ -406,6 +428,42 @@ def maybe_cast_pointwise_result(
406428
return result
407429

408430

431+
def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
432+
"""
433+
Get the desired dtype of a result based on the
434+
input dtype and how it was computed.
435+
436+
Parameters
437+
----------
438+
dtype : DtypeObj
439+
Input dtype.
440+
how : str
441+
How the result was computed.
442+
443+
Returns
444+
-------
445+
DtypeObj
446+
The desired dtype of the result.
447+
"""
448+
from pandas.core.arrays.boolean import BooleanDtype
449+
from pandas.core.arrays.floating import Float64Dtype
450+
from pandas.core.arrays.integer import (
451+
Int64Dtype,
452+
_IntegerDtype,
453+
)
454+
455+
if how in ["add", "cumsum", "sum", "prod"]:
456+
if dtype == np.dtype(bool):
457+
return np.dtype(np.int64)
458+
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
459+
return Int64Dtype()
460+
elif how in ["mean", "median", "var"] and isinstance(
461+
dtype, (BooleanDtype, _IntegerDtype)
462+
):
463+
return Float64Dtype()
464+
return dtype
465+
466+
409467
def maybe_cast_to_extension_array(
410468
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
411469
) -> ArrayLike:
@@ -729,9 +787,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj,
729787
except OutOfBoundsDatetime:
730788
return np.dtype(object), val
731789

732-
# error: Non-overlapping identity check (left operand type: "Timestamp",
733-
# right operand type: "NaTType")
734-
if val is NaT or val.tz is None: # type: ignore[comparison-overlap]
790+
if val is NaT or val.tz is None:
735791
dtype = np.dtype("M8[ns]")
736792
val = val.to_datetime64()
737793
else:
@@ -2058,7 +2114,7 @@ def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
20582114
ValueError
20592115
"""
20602116
# error: Argument 1 to "__call__" of "ufunc" has incompatible type
2061-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
2117+
# "Union[Union[str, int, float, bool], Union[Any, Any, Timedelta, Any]]";
20622118
# expected "Union[Union[int, float, complex, str, bytes, generic],
20632119
# Sequence[Union[int, float, complex, str, bytes, generic]],
20642120
# Sequence[Sequence[Any]], _SupportsArray]"

0 commit comments

Comments
 (0)