|
14 | 14 | is_datetime64_dtype,
|
15 | 15 | is_datetime64tz_dtype,
|
16 | 16 | is_datetime_or_timedelta_dtype,
|
| 17 | + is_extension_array_dtype, |
17 | 18 | is_integer,
|
| 19 | + is_integer_dtype, |
18 | 20 | is_list_like,
|
19 | 21 | is_scalar,
|
20 | 22 | is_timedelta64_dtype,
|
@@ -209,16 +211,28 @@ def cut(
|
209 | 211 | if is_scalar(bins) and bins < 1:
|
210 | 212 | raise ValueError("`bins` should be a positive integer.")
|
211 | 213 |
|
212 |
| - try: # for array-like |
213 |
| - sz = x.size |
| 214 | + # TODO: Support arbitrary Extension Arrays. |
| 215 | + # For now, we're only attempting to support IntegerArray. |
| 216 | + # See the note on _bins_to_cuts about what is needed. |
| 217 | + is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype( |
| 218 | + x.dtype |
| 219 | + ) |
| 220 | + try: |
| 221 | + if is_extension_array_dtype(x) and is_integer_dtype(x): |
| 222 | + sz = len(x) |
| 223 | + else: |
| 224 | + sz = x.size |
214 | 225 | except AttributeError:
|
215 | 226 | x = np.asarray(x)
|
216 | 227 | sz = x.size
|
217 | 228 |
|
218 | 229 | if sz == 0:
|
219 | 230 | raise ValueError("Cannot cut empty array")
|
220 | 231 |
|
221 |
| - rng = (nanops.nanmin(x), nanops.nanmax(x)) |
| 232 | + if is_nullable_integer: |
| 233 | + rng = x._reduce("min"), x._reduce("max") |
| 234 | + else: |
| 235 | + rng = (nanops.nanmin(x), nanops.nanmax(x)) |
222 | 236 | mn, mx = [mi + 0.0 for mi in rng]
|
223 | 237 |
|
224 | 238 | if np.isinf(mn) or np.isinf(mx):
|
@@ -383,10 +397,26 @@ def _bins_to_cuts(
|
383 | 397 | bins = unique_bins
|
384 | 398 |
|
385 | 399 | side = "left" if right else "right"
|
386 |
| - ids = ensure_int64(bins.searchsorted(x, side=side)) |
| 400 | + is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype( |
| 401 | + x.dtype |
| 402 | + ) |
| 403 | + |
| 404 | + if is_nullable_integer: |
| 405 | + # TODO: Support other extension types somehow. |
| 406 | + # We *could* use factorize here, but that does more than we need. |
| 407 | + # We just need some integer representation, and the NA values needn't |
| 408 | + # even be marked specially. |
| 409 | + x_int = x._ndarray_values |
| 410 | + ids = ensure_int64(bins.searchsorted(x_int, side=side)) |
| 411 | + else: |
| 412 | + ids = ensure_int64(bins.searchsorted(x, side=side)) |
387 | 413 |
|
388 | 414 | if include_lowest:
|
389 |
| - ids[x == bins[0]] = 1 |
| 415 | + mask = x == bins[0] |
| 416 | + if is_nullable_integer: |
| 417 | + # for nullable integer x, convert the comparison result to a plain bool ndarray (NA -> False) |
| 418 | + mask = mask.to_numpy(na_value=False, dtype=bool) |
| 419 | + ids[mask] = 1 |
390 | 420 |
|
391 | 421 | na_mask = isna(x) | (ids == len(bins)) | (ids == 0)
|
392 | 422 | has_nas = na_mask.any()
|
|
0 commit comments