46
46
)
47
47
from .dtypes import DatetimeTZDtype , ExtensionDtype , PeriodDtype
48
48
from .generic import (
49
+ ABCDataFrame ,
49
50
ABCDatetimeArray ,
50
51
ABCDatetimeIndex ,
51
52
ABCPeriodArray ,
@@ -95,12 +96,13 @@ def maybe_downcast_to_dtype(result, dtype):
95
96
""" try to cast to the specified dtype (e.g. convert back to bool/int
96
97
or could be an astype of float64->float32
97
98
"""
99
+ do_round = False
98
100
99
101
if is_scalar (result ):
100
102
return result
101
-
102
- def trans ( x ):
103
- return x
103
+ elif isinstance ( result , ABCDataFrame ):
104
+ # occurs in pivot_table doctest
105
+ return result
104
106
105
107
if isinstance (dtype , str ):
106
108
if dtype == "infer" :
@@ -118,83 +120,115 @@ def trans(x):
118
120
elif inferred_type == "floating" :
119
121
dtype = "int64"
120
122
if issubclass (result .dtype .type , np .number ):
121
-
122
- def trans (x ): # noqa
123
- return x .round ()
123
+ do_round = True
124
124
125
125
else :
126
126
dtype = "object"
127
127
128
- if isinstance (dtype , str ):
129
128
dtype = np .dtype (dtype )
130
129
131
- try :
130
+ converted = maybe_downcast_numeric (result , dtype , do_round )
131
+ if converted is not result :
132
+ return converted
133
+
134
+ # a datetimelike
135
+ # GH12821, iNaT is cast to float
136
+ if dtype .kind in ["M" , "m" ] and result .dtype .kind in ["i" , "f" ]:
137
+ try :
138
+ result = result .astype (dtype )
139
+ except Exception :
140
+ if dtype .tz :
141
+ # convert to datetime and change timezone
142
+ from pandas import to_datetime
143
+
144
+ result = to_datetime (result ).tz_localize ("utc" )
145
+ result = result .tz_convert (dtype .tz )
146
+
147
+ elif dtype .type is Period :
148
+ # TODO(DatetimeArray): merge with previous elif
149
+ from pandas .core .arrays import PeriodArray
132
150
151
+ try :
152
+ return PeriodArray (result , freq = dtype .freq )
153
+ except TypeError :
154
+ # e.g. TypeError: int() argument must be a string, a
155
+ # bytes-like object or a number, not 'Period'
156
+ pass
157
+
158
+ return result
159
+
160
+
161
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
    """
    Subset of maybe_downcast_to_dtype restricted to numeric dtypes.

    Try to cast ``result`` to ``dtype`` and return the cast array only when
    the cast values compare equal to the originals; otherwise ``result`` is
    returned unchanged.

    Parameters
    ----------
    result : ndarray or ExtensionArray
        A list is also accepted and is converted with ``np.array`` first.
    dtype : np.dtype or ExtensionDtype
        Target dtype.  Anything that is not an ``np.dtype`` instance makes
        this function return ``result`` untouched.
    do_round : bool
        If True, values are rounded with ``.round()`` before being cast.

    Returns
    -------
    ndarray or ExtensionArray
        Either the successfully converted array or the original ``result``.
    """
    if not isinstance(dtype, np.dtype):
        # e.g. SparseDtype has no itemsize attr
        return result

    if isinstance(result, list):
        # reached via groupby.agg _ohlc; really this should be handled
        # earlier
        result = np.array(result)

    def trans(x):
        # Optional rounding step applied before every astype below.
        if do_round:
            return x.round()
        return x

    if dtype.kind == result.dtype.kind:
        # don't allow upcasts here (except if empty)
        if result.dtype.itemsize <= dtype.itemsize and result.size:
            return result

    if is_bool_dtype(dtype) or is_integer_dtype(dtype):

        if not result.size:
            # if we don't have any elements, just astype it
            return trans(result).astype(dtype)

        # do a test on the first element, if it fails then we are done
        r = result.ravel()
        arr = np.array([r[0]])

        if isna(arr).any() or not np.allclose(
            arr, trans(arr).astype(dtype), rtol=0
        ):
            # if we have any nulls, then we are done
            return result

        # ``np.bool_`` (not the ``np.bool`` alias) is used because the alias
        # was removed in NumPy 1.24 and raised AttributeError there.
        elif not isinstance(
            r[0], (np.integer, np.floating, np.bool_, int, float, bool)
        ):
            # a comparable, e.g. a Decimal may slip in here
            return result

        if (
            issubclass(result.dtype.type, (np.object_, np.number))
            and notna(result).all()
        ):
            new_result = trans(result).astype(dtype)
            try:
                if np.allclose(new_result, result, rtol=0):
                    return new_result
            except Exception:
                # comparison of an object dtype with a number type could
                # hit here
                if (new_result == result).all():
                    return new_result

    elif (
        issubclass(dtype.type, np.floating)
        and not is_bool_dtype(result.dtype)
        and not is_string_dtype(result.dtype)
    ):
        return result.astype(dtype)

    return result
200
234
0 commit comments