Skip to content

Commit e6ec3b2

Browse files
committed
BUG: Handle IntegerArray in pd.cut
xref pandas-dev#30944. I think this doesn't close it, since only the pd.cut compoment is fixed.
1 parent 79f3b4c commit e6ec3b2

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

pandas/core/reshape/tile.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
is_datetime64_dtype,
1515
is_datetime64tz_dtype,
1616
is_datetime_or_timedelta_dtype,
17+
is_extension_array_dtype,
1718
is_integer,
19+
is_integer_dtype,
1820
is_list_like,
1921
is_scalar,
2022
is_timedelta64_dtype,
@@ -209,16 +211,28 @@ def cut(
209211
if is_scalar(bins) and bins < 1:
210212
raise ValueError("`bins` should be a positive integer.")
211213

212-
try: # for array-like
213-
sz = x.size
214+
# TODO: Support arbitrary Extension Arrays. We need
215+
# For now, we're only attempting to support IntegerArray.
216+
# See the note on _bins_to_cuts about what is needed.
217+
is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype(
218+
x.dtype
219+
)
220+
try:
221+
if is_extension_array_dtype(x) and is_integer_dtype(x):
222+
sz = len(x)
223+
else:
224+
sz = x.size
214225
except AttributeError:
215226
x = np.asarray(x)
216227
sz = x.size
217228

218229
if sz == 0:
219230
raise ValueError("Cannot cut empty array")
220231

221-
rng = (nanops.nanmin(x), nanops.nanmax(x))
232+
if is_nullable_integer:
233+
rng = x._reduce("min"), x._reduce("max")
234+
else:
235+
rng = (nanops.nanmin(x), nanops.nanmax(x))
222236
mn, mx = [mi + 0.0 for mi in rng]
223237

224238
if np.isinf(mn) or np.isinf(mx):
@@ -383,10 +397,26 @@ def _bins_to_cuts(
383397
bins = unique_bins
384398

385399
side = "left" if right else "right"
386-
ids = ensure_int64(bins.searchsorted(x, side=side))
400+
is_nullable_integer = is_extension_array_dtype(x.dtype) and is_integer_dtype(
401+
x.dtype
402+
)
403+
404+
if is_nullable_integer:
405+
# TODO: Support other extension types somehow. We don't currently
406+
# We *could* use factorize here, but that does more that we need.
407+
# We just need some integer representation, and the NA values needn't
408+
# even be marked specially.
409+
x_int = x._ndarray_values
410+
ids = ensure_int64(bins.searchsorted(x_int, side=side))
411+
else:
412+
ids = ensure_int64(bins.searchsorted(x, side=side))
387413

388414
if include_lowest:
389-
ids[x == bins[0]] = 1
415+
mask = x == bins[0]
416+
if is_nullable_integer:
417+
# when x is integer
418+
mask = mask.to_numpy(na_value=False, dtype=bool)
419+
ids[mask] = 1
390420

391421
na_mask = isna(x) | (ids == len(bins)) | (ids == 0)
392422
has_nas = na_mask.any()

pandas/tests/arrays/test_integer.py

+14
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,20 @@ def test_value_counts_na():
10611061
tm.assert_series_equal(result, expected)
10621062

10631063

1064+
@pytest.mark.parametrize("bins", [3, [0, 5, 15]])
1065+
@pytest.mark.parametrize("right", [True, False])
1066+
@pytest.mark.parametrize("include_lowest", [True, False])
1067+
def test_cut(bins, right, include_lowest):
1068+
a = np.random.randint(0, 10, size=50).astype(float)
1069+
a[::2] = np.nan
1070+
tm.assert_categorical_equal(
1071+
pd.cut(
1072+
pd.array(a, dtype="Int64"), bins, right=right, include_lowest=include_lowest
1073+
),
1074+
pd.cut(a, bins, right=right, include_lowest=include_lowest),
1075+
)
1076+
1077+
10641078
# TODO(jreback) - these need testing / are broken
10651079

10661080
# shift

0 commit comments

Comments
 (0)