-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
REF/CLN: maybe_downcast_to_dtype #27714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
dc77ed4
cce455d
7aa52ef
1bb828d
85e542a
a3806ec
9315929
92f04fb
44bf1e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,7 @@ | |
) | ||
from .dtypes import DatetimeTZDtype, ExtensionDtype, PeriodDtype | ||
from .generic import ( | ||
ABCDataFrame, | ||
ABCDatetimeArray, | ||
ABCDatetimeIndex, | ||
ABCPeriodArray, | ||
|
@@ -95,12 +96,13 @@ def maybe_downcast_to_dtype(result, dtype): | |
""" try to cast to the specified dtype (e.g. convert back to bool/int | ||
or could be an astype of float64->float32 | ||
""" | ||
do_round = False | ||
|
||
if is_scalar(result): | ||
return result | ||
|
||
def trans(x): | ||
return x | ||
elif isinstance(result, ABCDataFrame): | ||
# occurs in pivot_table doctest | ||
return result | ||
|
||
if isinstance(dtype, str): | ||
if dtype == "infer": | ||
|
@@ -118,83 +120,115 @@ def trans(x): | |
elif inferred_type == "floating": | ||
dtype = "int64" | ||
if issubclass(result.dtype.type, np.number): | ||
|
||
def trans(x): # noqa | ||
return x.round() | ||
do_round = True | ||
|
||
else: | ||
dtype = "object" | ||
|
||
if isinstance(dtype, str): | ||
dtype = np.dtype(dtype) | ||
|
||
try: | ||
converted = maybe_downcast_numeric(result, dtype, do_round) | ||
if converted is not result: | ||
return converted | ||
|
||
# a datetimelike | ||
# GH12821, iNaT is casted to float | ||
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]: | ||
try: | ||
result = result.astype(dtype) | ||
except Exception: | ||
if dtype.tz: | ||
# convert to datetime and change timezone | ||
from pandas import to_datetime | ||
|
||
result = to_datetime(result).tz_localize("utc") | ||
result = result.tz_convert(dtype.tz) | ||
|
||
elif dtype.type == Period: | ||
# TODO(DatetimeArray): merge with previous elif | ||
from pandas.core.arrays import PeriodArray | ||
|
||
try: | ||
return PeriodArray(result, freq=dtype.freq) | ||
except TypeError: | ||
# e.g. TypeError: int() argument must be a string, a | ||
# bytes-like object or a number, not 'Period | ||
pass | ||
|
||
return result | ||
|
||
|
||
def maybe_downcast_numeric(result, dtype, do_round: bool = False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is very similar to to_numeric; would plan as a followup to move to_numeric logic here and call this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good idea |
||
""" | ||
Subset of maybe_downcast_to_dtype restricted to numeric dtypes. | ||
|
||
Parameters | ||
---------- | ||
result : ndarray or ExtensionArray | ||
dtype : np.dtype or ExtensionDtype | ||
do_round : bool | ||
|
||
Returns | ||
------- | ||
ndarray or ExtensionArray | ||
""" | ||
if not isinstance(dtype, np.dtype): | ||
# e.g. SparseDtype has no itemsize attr | ||
return result | ||
|
||
if isinstance(result, list): | ||
# reached via groupoby.agg _ohlc; really this should be handled | ||
# earlier | ||
result = np.array(result) | ||
|
||
def trans(x): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rather than doing this, I would pass in a callable directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that gives a lot more degrees of freedom to the caller, I'd rather it just be a bool kwarg until/unless we need something more |
||
if do_round: | ||
return x.round() | ||
return x | ||
|
||
if dtype.kind == result.dtype.kind: | ||
# don't allow upcasts here (except if empty) | ||
if dtype.kind == result.dtype.kind: | ||
if result.dtype.itemsize <= dtype.itemsize and np.prod(result.shape): | ||
return result | ||
if result.dtype.itemsize <= dtype.itemsize and result.size: | ||
return result | ||
|
||
if is_bool_dtype(dtype) or is_integer_dtype(dtype): | ||
if is_bool_dtype(dtype) or is_integer_dtype(dtype): | ||
|
||
if not result.size: | ||
# if we don't have any elements, just astype it | ||
if not np.prod(result.shape): | ||
return trans(result).astype(dtype) | ||
return trans(result).astype(dtype) | ||
|
||
# do a test on the first element, if it fails then we are done | ||
r = result.ravel() | ||
arr = np.array([r[0]]) | ||
# do a test on the first element, if it fails then we are done | ||
r = result.ravel() | ||
arr = np.array([r[0]]) | ||
|
||
if isna(arr).any() or not np.allclose(arr, trans(arr).astype(dtype), rtol=0): | ||
# if we have any nulls, then we are done | ||
if isna(arr).any() or not np.allclose( | ||
arr, trans(arr).astype(dtype), rtol=0 | ||
): | ||
return result | ||
return result | ||
|
||
elif not isinstance(r[0], (np.integer, np.floating, np.bool, int, float, bool)): | ||
# a comparable, e.g. a Decimal may slip in here | ||
elif not isinstance( | ||
r[0], (np.integer, np.floating, np.bool, int, float, bool) | ||
): | ||
return result | ||
return result | ||
|
||
if ( | ||
issubclass(result.dtype.type, (np.object_, np.number)) | ||
and notna(result).all() | ||
): | ||
new_result = trans(result).astype(dtype) | ||
try: | ||
if np.allclose(new_result, result, rtol=0): | ||
return new_result | ||
except Exception: | ||
|
||
# comparison of an object dtype with a number type could | ||
# hit here | ||
if (new_result == result).all(): | ||
return new_result | ||
elif issubclass(dtype.type, np.floating) and not is_bool_dtype(result.dtype): | ||
return result.astype(dtype) | ||
|
||
# a datetimelike | ||
# GH12821, iNaT is casted to float | ||
elif dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]: | ||
if ( | ||
issubclass(result.dtype.type, (np.object_, np.number)) | ||
and notna(result).all() | ||
): | ||
new_result = trans(result).astype(dtype) | ||
try: | ||
result = result.astype(dtype) | ||
if np.allclose(new_result, result, rtol=0): | ||
return new_result | ||
except Exception: | ||
if dtype.tz: | ||
# convert to datetime and change timezone | ||
from pandas import to_datetime | ||
|
||
result = to_datetime(result).tz_localize("utc") | ||
result = result.tz_convert(dtype.tz) | ||
|
||
elif dtype.type == Period: | ||
# TODO(DatetimeArray): merge with previous elif | ||
from pandas.core.arrays import PeriodArray | ||
|
||
return PeriodArray(result, freq=dtype.freq) | ||
|
||
except Exception: | ||
pass | ||
# comparison of an object dtype with a number type could | ||
# hit here | ||
if (new_result == result).all(): | ||
return new_result | ||
|
||
elif ( | ||
issubclass(dtype.type, np.floating) | ||
and not is_bool_dtype(result.dtype) | ||
and not is_string_dtype(result.dtype) | ||
): | ||
return result.astype(dtype) | ||
|
||
return result | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about isintance instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually i think the
==
should beis
, will update