Skip to content

Commit b5854c4

Browse files
authored
REF: cast x and bins to Index early in cut, qcut (#54919)
* REF: cast x and bins to Index early in cut, qcut
* mypy fixup
* troubleshoot builds
1 parent ace97f9 commit b5854c4

File tree

1 file changed

+61
-45
lines changed

1 file changed

+61
-45
lines changed

pandas/core/reshape/tile.py

+61-45
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
to_datetime,
4444
to_timedelta,
4545
)
46-
from pandas.core import nanops
4746
import pandas.core.algorithms as algos
4847

4948
if TYPE_CHECKING:
@@ -243,43 +242,18 @@ def cut(
243242
# NOTE: this binning code is changed a bit from histogram for var(x) == 0
244243

245244
original = x
246-
x = _preprocess_for_cut(x)
247-
x, dtype = _coerce_to_type(x)
245+
x_idx = _preprocess_for_cut(x)
246+
x_idx, dtype = _coerce_to_type(x_idx)
248247

249248
if not np.iterable(bins):
250-
if is_scalar(bins) and bins < 1:
251-
raise ValueError("`bins` should be a positive integer.")
252-
253-
sz = x.size
254-
255-
if sz == 0:
256-
raise ValueError("Cannot cut empty array")
257-
258-
rng = (nanops.nanmin(x), nanops.nanmax(x))
259-
mn, mx = (mi + 0.0 for mi in rng)
260-
261-
if np.isinf(mn) or np.isinf(mx):
262-
# GH 24314
263-
raise ValueError(
264-
"cannot specify integer `bins` when input data contains infinity"
265-
)
266-
if mn == mx: # adjust end points before binning
267-
mn -= 0.001 * abs(mn) if mn != 0 else 0.001
268-
mx += 0.001 * abs(mx) if mx != 0 else 0.001
269-
bins = np.linspace(mn, mx, bins + 1, endpoint=True)
270-
else: # adjust end points after binning
271-
bins = np.linspace(mn, mx, bins + 1, endpoint=True)
272-
adj = (mx - mn) * 0.001 # 0.1% of the range
273-
if right:
274-
bins[0] -= adj
275-
else:
276-
bins[-1] += adj
249+
bins = _nbins_to_bins(x_idx, bins, right)
277250

278251
elif isinstance(bins, IntervalIndex):
279252
if bins.is_overlapping:
280253
raise ValueError("Overlapping IntervalIndex is not accepted.")
281254

282255
else:
256+
bins = Index(bins)
283257
if isinstance(getattr(bins, "dtype", None), DatetimeTZDtype):
284258
bins = np.asarray(bins, dtype=DT64NS_DTYPE)
285259
else:
@@ -289,9 +263,10 @@ def cut(
289263
# GH 26045: cast to float64 to avoid an overflow
290264
if (np.diff(bins.astype("float64")) < 0).any():
291265
raise ValueError("bins must increase monotonically.")
266+
bins = Index(bins)
292267

293268
fac, bins = _bins_to_cuts(
294-
x,
269+
x_idx,
295270
bins,
296271
right=right,
297272
labels=labels,
@@ -367,18 +342,18 @@ def qcut(
367342
array([0, 0, 1, 2, 3])
368343
"""
369344
original = x
370-
x = _preprocess_for_cut(x)
371-
x, dtype = _coerce_to_type(x)
345+
x_idx = _preprocess_for_cut(x)
346+
x_idx, dtype = _coerce_to_type(x_idx)
372347

373348
quantiles = np.linspace(0, 1, q + 1) if is_integer(q) else q
374349

375-
x_np = np.asarray(x)
350+
x_np = np.asarray(x_idx)
376351
x_np = x_np[~np.isnan(x_np)]
377352
bins = np.quantile(x_np, quantiles)
378353

379354
fac, bins = _bins_to_cuts(
380-
x,
381-
bins,
355+
x_idx,
356+
Index(bins),
382357
labels=labels,
383358
precision=precision,
384359
include_lowest=True,
@@ -389,9 +364,44 @@ def qcut(
389364
return _postprocess_for_cut(fac, bins, retbins, dtype, original)
390365

391366

367+
def _nbins_to_bins(x_idx: Index, nbins: int, right: bool) -> Index:
368+
"""
369+
If a user passed an integer N for bins, convert this to a sequence of N
370+
equal(ish)-sized bins.
371+
"""
372+
if is_scalar(nbins) and nbins < 1:
373+
raise ValueError("`bins` should be a positive integer.")
374+
375+
if x_idx.size == 0:
376+
raise ValueError("Cannot cut empty array")
377+
378+
rng = (x_idx.min(), x_idx.max())
379+
mn, mx = rng
380+
381+
if np.isinf(mn) or np.isinf(mx):
382+
# GH#24314
383+
raise ValueError(
384+
"cannot specify integer `bins` when input data contains infinity"
385+
)
386+
387+
if mn == mx: # adjust end points before binning
388+
mn -= 0.001 * abs(mn) if mn != 0 else 0.001
389+
mx += 0.001 * abs(mx) if mx != 0 else 0.001
390+
bins = np.linspace(mn, mx, nbins + 1, endpoint=True)
391+
else: # adjust end points after binning
392+
bins = np.linspace(mn, mx, nbins + 1, endpoint=True)
393+
adj = (mx - mn) * 0.001 # 0.1% of the range
394+
if right:
395+
bins[0] -= adj
396+
else:
397+
bins[-1] += adj
398+
399+
return Index(bins)
400+
401+
392402
def _bins_to_cuts(
393-
x,
394-
bins: np.ndarray,
403+
x: Index,
404+
bins: Index,
395405
right: bool = True,
396406
labels=None,
397407
precision: int = 3,
@@ -408,6 +418,8 @@ def _bins_to_cuts(
408418
"invalid value for 'duplicates' parameter, valid options are: raise, drop"
409419
)
410420

421+
result: Categorical | np.ndarray
422+
411423
if isinstance(bins, IntervalIndex):
412424
# we have a fast-path here
413425
ids = bins.get_indexer(x)
@@ -474,7 +486,7 @@ def _bins_to_cuts(
474486
return result, bins
475487

476488

477-
def _coerce_to_type(x):
489+
def _coerce_to_type(x: Index) -> tuple[Index, DtypeObj | None]:
478490
"""
479491
if the passed data is of datetime/timedelta, bool or nullable int type,
480492
this method converts it to numeric so that cut or qcut method can
@@ -498,11 +510,13 @@ def _coerce_to_type(x):
498510
# https://github.com/pandas-dev/pandas/pull/31290
499511
# https://github.com/pandas-dev/pandas/issues/31389
500512
elif isinstance(x.dtype, ExtensionDtype) and is_numeric_dtype(x.dtype):
501-
x = x.to_numpy(dtype=np.float64, na_value=np.nan)
513+
x_arr = x.to_numpy(dtype=np.float64, na_value=np.nan)
514+
x = Index(x_arr)
502515

503516
if dtype is not None:
504517
# GH 19768: force NaT to NaN during integer conversion
505-
x = np.where(x.notna(), x.view(np.int64), np.nan)
518+
x_arr = np.where(x.notna(), x.view(np.int64), np.nan)
519+
x = Index(x_arr)
506520

507521
return x, dtype
508522

@@ -564,7 +578,7 @@ def _convert_bin_to_datelike_type(bins, dtype: DtypeObj | None):
564578

565579

566580
def _format_labels(
567-
bins,
581+
bins: Index,
568582
precision: int,
569583
right: bool = True,
570584
include_lowest: bool = False,
@@ -597,7 +611,7 @@ def _format_labels(
597611
return IntervalIndex.from_breaks(breaks, closed=closed)
598612

599613

600-
def _preprocess_for_cut(x):
614+
def _preprocess_for_cut(x) -> Index:
601615
"""
602616
handles preprocessing for cut where we convert passed
603617
input to array, strip the index information and store it
@@ -611,7 +625,7 @@ def _preprocess_for_cut(x):
611625
if x.ndim != 1:
612626
raise ValueError("Input array must be 1 dimensional")
613627

614-
return x
628+
return Index(x)
615629

616630

617631
def _postprocess_for_cut(fac, bins, retbins: bool, dtype: DtypeObj | None, original):
@@ -627,6 +641,8 @@ def _postprocess_for_cut(fac, bins, retbins: bool, dtype: DtypeObj | None, origi
627641
return fac
628642

629643
bins = _convert_bin_to_datelike_type(bins, dtype)
644+
if isinstance(bins, Index) and is_numeric_dtype(bins.dtype):
645+
bins = bins._values
630646

631647
return fac, bins
632648

@@ -646,7 +662,7 @@ def _round_frac(x, precision: int):
646662
return np.around(x, digits)
647663

648664

649-
def _infer_precision(base_precision: int, bins) -> int:
665+
def _infer_precision(base_precision: int, bins: Index) -> int:
650666
"""
651667
Infer an appropriate precision for _round_frac
652668
"""

0 commit comments

Comments
 (0)