-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG: Handle IntegerArray in pd.cut #31290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,9 @@ | |
is_datetime64_dtype, | ||
is_datetime64tz_dtype, | ||
is_datetime_or_timedelta_dtype, | ||
is_extension_array_dtype, | ||
is_integer, | ||
is_integer_dtype, | ||
is_list_like, | ||
is_scalar, | ||
is_timedelta64_dtype, | ||
|
@@ -209,16 +211,28 @@ def cut( | |
if is_scalar(bins) and bins < 1: | ||
raise ValueError("`bins` should be a positive integer.") | ||
|
||
try: # for array-like | ||
sz = x.size | ||
# TODO: Support arbitrary Extension Arrays. We need | ||
# For now, we're only attempting to support IntegerArray. | ||
# See the note on _bins_to_cuts about what is needed. | ||
is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't is_integer_dtype suffice here? |
||
x.dtype | ||
) | ||
try: | ||
if is_extension_array_dtype(x) and is_integer_dtype(x): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we just do len(x)? |
||
sz = len(x) | ||
else: | ||
sz = x.size | ||
except AttributeError: | ||
x = np.asarray(x) | ||
sz = x.size | ||
|
||
if sz == 0: | ||
raise ValueError("Cannot cut empty array") | ||
|
||
rng = (nanops.nanmin(x), nanops.nanmax(x)) | ||
if is_nullable_integer: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IntegerArray doesn't have a min / max yet. |
||
rng = x._reduce("min"), x._reduce("max") | ||
else: | ||
rng = (nanops.nanmin(x), nanops.nanmax(x)) | ||
mn, mx = [mi + 0.0 for mi in rng] | ||
|
||
if np.isinf(mn) or np.isinf(mx): | ||
|
@@ -383,10 +397,26 @@ def _bins_to_cuts( | |
bins = unique_bins | ||
|
||
side = "left" if right else "right" | ||
ids = ensure_int64(bins.searchsorted(x, side=side)) | ||
is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as above |
||
x.dtype | ||
) | ||
|
||
if is_nullable_integer: | ||
# TODO: Support other extension types somehow. We don't currently | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit interesting. We need to get integers for searchsorted and it doesn't really matter how the NA values are encoded since we mask them out later on. I don't think we have anything in the interface like this right now. The closest is
Which is more work than we need here. Worth thinking about for the future. |
||
# We *could* use factorize here, but that does more that we need. | ||
# We just need some integer representation, and the NA values needn't | ||
# even be marked specially. | ||
x_int = x._ndarray_values | ||
ids = ensure_int64(bins.searchsorted(x_int, side=side)) | ||
else: | ||
ids = ensure_int64(bins.searchsorted(x, side=side)) | ||
|
||
if include_lowest: | ||
ids[x == bins[0]] = 1 | ||
mask = x == bins[0] | ||
if is_nullable_integer: | ||
# when x is integer | ||
mask = mask.to_numpy(na_value=False, dtype=bool) | ||
ids[mask] = 1 | ||
|
||
na_mask = isna(x) | (ids == len(bins)) | (ids == 0) | ||
has_nas = na_mask.any() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.