Skip to content

Commit 4340689

Browse files
authored
REF: enforce annotation in maybe_downcast_to_dtype (#40982)
1 parent ece1217 commit 4340689

File tree

4 files changed

+21
-48
lines changed

4 files changed

+21
-48
lines changed

pandas/core/dtypes/cast.py

+4-24
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from __future__ import annotations
66

7-
from contextlib import suppress
87
from datetime import (
98
date,
109
datetime,
@@ -29,7 +28,6 @@
2928
NaT,
3029
OutOfBoundsDatetime,
3130
OutOfBoundsTimedelta,
32-
Period,
3331
Timedelta,
3432
Timestamp,
3533
conversion,
@@ -87,7 +85,6 @@
8785
PeriodDtype,
8886
)
8987
from pandas.core.dtypes.generic import (
90-
ABCDataFrame,
9188
ABCExtensionArray,
9289
ABCSeries,
9390
)
@@ -249,9 +246,6 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
249246
try to cast to the specified dtype (e.g. convert back to bool/int
250247
or could be an astype of float64->float32
251248
"""
252-
if isinstance(result, ABCDataFrame):
253-
# see test_pivot_table_doctest_case
254-
return result
255249
do_round = False
256250

257251
if isinstance(dtype, str):
@@ -278,15 +272,9 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
278272

279273
dtype = np.dtype(dtype)
280274

281-
elif dtype.type is Period:
282-
from pandas.core.arrays import PeriodArray
283-
284-
with suppress(TypeError):
285-
# e.g. TypeError: int() argument must be a string, a
286-
# bytes-like object or a number, not 'Period
287-
288-
# error: "dtype[Any]" has no attribute "freq"
289-
return PeriodArray(result, freq=dtype.freq) # type: ignore[attr-defined]
275+
if not isinstance(dtype, np.dtype):
276+
# enforce our signature annotation
277+
raise TypeError(dtype) # pragma: no cover
290278

291279
converted = maybe_downcast_numeric(result, dtype, do_round)
292280
if converted is not result:
@@ -295,15 +283,7 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
295283
# a datetimelike
296284
# GH12821, iNaT is cast to float
297285
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
298-
if isinstance(dtype, DatetimeTZDtype):
299-
# convert to datetime and change timezone
300-
i8values = result.astype("i8", copy=False)
301-
cls = dtype.construct_array_type()
302-
# equiv: DatetimeArray(i8values).tz_localize("UTC").tz_convert(dtype.tz)
303-
dt64values = i8values.view("M8[ns]")
304-
result = cls._simple_new(dt64values, dtype=dtype)
305-
else:
306-
result = result.astype(dtype)
286+
result = result.astype(dtype)
307287

308288
return result
309289

pandas/core/frame.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -7213,13 +7213,14 @@ def combine(
72137213
else:
72147214
# if we have different dtypes, possibly promote
72157215
new_dtype = find_common_type([this_dtype, other_dtype])
7216-
if not is_dtype_equal(this_dtype, new_dtype):
7217-
series = series.astype(new_dtype)
7218-
if not is_dtype_equal(other_dtype, new_dtype):
7219-
otherSeries = otherSeries.astype(new_dtype)
7216+
series = series.astype(new_dtype, copy=False)
7217+
otherSeries = otherSeries.astype(new_dtype, copy=False)
72207218

72217219
arr = func(series, otherSeries)
7222-
arr = maybe_downcast_to_dtype(arr, new_dtype)
7220+
if isinstance(new_dtype, np.dtype):
7221+
# if new_dtype is an EA Dtype, then `func` is expected to return
7222+
# the correct dtype without any additional casting
7223+
arr = maybe_downcast_to_dtype(arr, new_dtype)
72237224

72247225
result[col] = arr
72257226

pandas/core/reshape/pivot.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,15 @@ def __internal_pivot_table(
174174
and v in agged
175175
and not is_integer_dtype(agged[v])
176176
):
177-
agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)
177+
if isinstance(agged[v], ABCDataFrame):
178+
# exclude DataFrame case bc maybe_downcast_to_dtype expects
179+
# ArrayLike
180+
# TODO: why does test_pivot_table_doctest_case fail if
181+
# we don't do this apparently-unnecessary setitem?
182+
agged[v] = agged[v]
183+
pass
184+
else:
185+
agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)
178186

179187
table = agged
180188

pandas/tests/dtypes/cast/test_downcast.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
77

8-
from pandas import (
9-
DatetimeIndex,
10-
Series,
11-
Timestamp,
12-
)
8+
from pandas import Series
139
import pandas._testing as tm
1410

1511

@@ -77,7 +73,7 @@ def test_downcast_conversion_nan(float_dtype):
7773
def test_downcast_conversion_empty(any_real_dtype):
7874
dtype = any_real_dtype
7975
arr = np.array([], dtype=dtype)
80-
result = maybe_downcast_to_dtype(arr, "int64")
76+
result = maybe_downcast_to_dtype(arr, np.dtype("int64"))
8177
tm.assert_numpy_array_equal(result, np.array([], dtype=np.int64))
8278

8379

@@ -89,15 +85,3 @@ def test_datetime_likes_nan(klass):
8985
exp = np.array([1, 2, klass("NaT")], dtype)
9086
res = maybe_downcast_to_dtype(arr, dtype)
9187
tm.assert_numpy_array_equal(res, exp)
92-
93-
94-
@pytest.mark.parametrize("as_asi", [True, False])
95-
def test_datetime_with_timezone(as_asi):
96-
# see gh-15426
97-
ts = Timestamp("2016-01-01 12:00:00", tz="US/Pacific")
98-
exp = DatetimeIndex([ts, ts])._data
99-
100-
obj = exp.asi8 if as_asi else exp
101-
res = maybe_downcast_to_dtype(obj, exp.dtype)
102-
103-
tm.assert_datetime_array_equal(res, exp)

0 commit comments

Comments
 (0)