def maybe_downcast_numeric(result, dtype, do_round: bool = False):
    """
    Subset of maybe_downcast_to_dtype restricted to numeric dtypes.

    Attempt to cast ``result`` to ``dtype`` and return the converted values
    only when the cast is lossless (checked with ``np.allclose`` using
    ``rtol=0``).  When no conversion is performed the array is returned
    untouched, so an ndarray/ExtensionArray input comes back as the very
    same object.

    Parameters
    ----------
    result : ndarray or ExtensionArray
        Values to downcast.  A plain ``list`` is also accepted and is
        converted with ``np.array`` first.
    dtype : np.dtype or ExtensionDtype
        Target dtype.  Anything that is not an ``np.dtype`` instance causes
        ``result`` to be returned unchanged.
    do_round : bool, default False
        If True, values are rounded with ``.round()`` before being cast to
        a bool/integer dtype.

    Returns
    -------
    ndarray or ExtensionArray
        Either a new array of ``dtype`` or the (unconverted) input.
    """
    if not isinstance(dtype, np.dtype):
        # e.g. SparseDtype has no itemsize attr
        return result

    if isinstance(result, list):
        # reached via groupby.agg _ohlc; really this should be handled
        # earlier
        result = np.array(result)

    def trans(x):
        # Optional pre-cast rounding step, controlled by ``do_round``.
        if do_round:
            return x.round()
        return x

    if dtype.kind == result.dtype.kind:
        # don't allow upcasts here (except if empty)
        if result.dtype.itemsize <= dtype.itemsize and result.size:
            return result

    if is_bool_dtype(dtype) or is_integer_dtype(dtype):
        if not result.size:
            # if we don't have any elements, just astype it
            return trans(result).astype(dtype)

        # do a test on the first element, if it fails then we are done
        r = result.ravel()
        arr = np.array([r[0]])

        if isna(arr).any() or not np.allclose(arr, trans(arr).astype(dtype), rtol=0):
            # if we have any nulls, then we are done
            return result

        # ``np.bool_`` (not the ``np.bool`` alias, which NumPy deprecated in
        # 1.20 and removed in 1.24) so this check works on every NumPy
        # version.
        elif not isinstance(
            r[0], (np.integer, np.floating, np.bool_, int, float, bool)
        ):
            # a comparable, e.g. a Decimal may slip in here
            return result

        if (
            issubclass(result.dtype.type, (np.object_, np.number))
            and notna(result).all()
        ):
            new_result = trans(result).astype(dtype)
            try:
                if np.allclose(new_result, result, rtol=0):
                    return new_result
            except Exception:
                # comparison of an object dtype with a number type could
                # hit here
                if (new_result == result).all():
                    return new_result

    elif (
        issubclass(dtype.type, np.floating)
        and not is_bool_dtype(result.dtype)
        and not is_string_dtype(result.dtype)
    ):
        # Float target from a non-bool, non-string source: cast directly.
        return result.astype(dtype)

    return result