Skip to content

Commit 6af6d51

Browse files
jbrockmendeljreback
authored and committed
REF/CLN: maybe_downcast_to_dtype (#27714)
1 parent 447d2c5 commit 6af6d51

File tree

1 file changed

+95
-61
lines changed

1 file changed

+95
-61
lines changed

pandas/core/dtypes/cast.py

+95-61
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from .dtypes import DatetimeTZDtype, ExtensionDtype, PeriodDtype
4848
from .generic import (
49+
ABCDataFrame,
4950
ABCDatetimeArray,
5051
ABCDatetimeIndex,
5152
ABCPeriodArray,
@@ -95,12 +96,13 @@ def maybe_downcast_to_dtype(result, dtype):
9596
""" try to cast to the specified dtype (e.g. convert back to bool/int
9697
or could be an astype of float64->float32
9798
"""
99+
do_round = False
98100

99101
if is_scalar(result):
100102
return result
101-
102-
def trans(x):
103-
return x
103+
elif isinstance(result, ABCDataFrame):
104+
# occurs in pivot_table doctest
105+
return result
104106

105107
if isinstance(dtype, str):
106108
if dtype == "infer":
@@ -118,83 +120,115 @@ def trans(x):
118120
elif inferred_type == "floating":
119121
dtype = "int64"
120122
if issubclass(result.dtype.type, np.number):
121-
122-
def trans(x): # noqa
123-
return x.round()
123+
do_round = True
124124

125125
else:
126126
dtype = "object"
127127

128-
if isinstance(dtype, str):
129128
dtype = np.dtype(dtype)
130129

131-
try:
130+
converted = maybe_downcast_numeric(result, dtype, do_round)
131+
if converted is not result:
132+
return converted
133+
134+
# a datetimelike
135+
# GH12821, iNaT is cast to float
136+
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
137+
try:
138+
result = result.astype(dtype)
139+
except Exception:
140+
if dtype.tz:
141+
# convert to datetime and change timezone
142+
from pandas import to_datetime
143+
144+
result = to_datetime(result).tz_localize("utc")
145+
result = result.tz_convert(dtype.tz)
146+
147+
elif dtype.type is Period:
148+
# TODO(DatetimeArray): merge with previous elif
149+
from pandas.core.arrays import PeriodArray
132150

151+
try:
152+
return PeriodArray(result, freq=dtype.freq)
153+
except TypeError:
154+
# e.g. TypeError: int() argument must be a string, a
155+
# bytes-like object or a number, not 'Period'
156+
pass
157+
158+
return result
159+
160+
161+
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
162+
"""
163+
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
164+
165+
Parameters
166+
----------
167+
result : ndarray or ExtensionArray
168+
dtype : np.dtype or ExtensionDtype
169+
do_round : bool
170+
171+
Returns
172+
-------
173+
ndarray or ExtensionArray
174+
"""
175+
if not isinstance(dtype, np.dtype):
176+
# e.g. SparseDtype has no itemsize attr
177+
return result
178+
179+
if isinstance(result, list):
180+
# reached via groupby.agg _ohlc; really this should be handled
181+
# earlier
182+
result = np.array(result)
183+
184+
def trans(x):
185+
if do_round:
186+
return x.round()
187+
return x
188+
189+
if dtype.kind == result.dtype.kind:
133190
# don't allow upcasts here (except if empty)
134-
if dtype.kind == result.dtype.kind:
135-
if result.dtype.itemsize <= dtype.itemsize and np.prod(result.shape):
136-
return result
191+
if result.dtype.itemsize <= dtype.itemsize and result.size:
192+
return result
137193

138-
if is_bool_dtype(dtype) or is_integer_dtype(dtype):
194+
if is_bool_dtype(dtype) or is_integer_dtype(dtype):
139195

196+
if not result.size:
140197
# if we don't have any elements, just astype it
141-
if not np.prod(result.shape):
142-
return trans(result).astype(dtype)
198+
return trans(result).astype(dtype)
143199

144-
# do a test on the first element, if it fails then we are done
145-
r = result.ravel()
146-
arr = np.array([r[0]])
200+
# do a test on the first element, if it fails then we are done
201+
r = result.ravel()
202+
arr = np.array([r[0]])
147203

204+
if isna(arr).any() or not np.allclose(arr, trans(arr).astype(dtype), rtol=0):
148205
# if we have any nulls, then we are done
149-
if isna(arr).any() or not np.allclose(
150-
arr, trans(arr).astype(dtype), rtol=0
151-
):
152-
return result
206+
return result
153207

208+
elif not isinstance(r[0], (np.integer, np.floating, np.bool, int, float, bool)):
154209
# a comparable, e.g. a Decimal may slip in here
155-
elif not isinstance(
156-
r[0], (np.integer, np.floating, np.bool, int, float, bool)
157-
):
158-
return result
210+
return result
159211

160-
if (
161-
issubclass(result.dtype.type, (np.object_, np.number))
162-
and notna(result).all()
163-
):
164-
new_result = trans(result).astype(dtype)
165-
try:
166-
if np.allclose(new_result, result, rtol=0):
167-
return new_result
168-
except Exception:
169-
170-
# comparison of an object dtype with a number type could
171-
# hit here
172-
if (new_result == result).all():
173-
return new_result
174-
elif issubclass(dtype.type, np.floating) and not is_bool_dtype(result.dtype):
175-
return result.astype(dtype)
176-
177-
# a datetimelike
178-
# GH12821, iNaT is cast to float
179-
elif dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
212+
if (
213+
issubclass(result.dtype.type, (np.object_, np.number))
214+
and notna(result).all()
215+
):
216+
new_result = trans(result).astype(dtype)
180217
try:
181-
result = result.astype(dtype)
218+
if np.allclose(new_result, result, rtol=0):
219+
return new_result
182220
except Exception:
183-
if dtype.tz:
184-
# convert to datetime and change timezone
185-
from pandas import to_datetime
186-
187-
result = to_datetime(result).tz_localize("utc")
188-
result = result.tz_convert(dtype.tz)
189-
190-
elif dtype.type == Period:
191-
# TODO(DatetimeArray): merge with previous elif
192-
from pandas.core.arrays import PeriodArray
193-
194-
return PeriodArray(result, freq=dtype.freq)
195-
196-
except Exception:
197-
pass
221+
# comparison of an object dtype with a number type could
222+
# hit here
223+
if (new_result == result).all():
224+
return new_result
225+
226+
elif (
227+
issubclass(dtype.type, np.floating)
228+
and not is_bool_dtype(result.dtype)
229+
and not is_string_dtype(result.dtype)
230+
):
231+
return result.astype(dtype)
198232

199233
return result
200234

0 commit comments

Comments
 (0)