Skip to content

Commit 0d977e9

Browse files
jbrockmendel authored and jreback committed
REF: simplify core.algorithms, reshape.cut (#29385)
1 parent 6cc8234 commit 0d977e9

File tree

2 files changed

+26
-54
lines changed

2 files changed

+26
-54
lines changed

pandas/core/algorithms.py

+13 -23
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pandas._libs import algos, hashtable as htable, lib
11+
from pandas._libs import Timestamp, algos, hashtable as htable, lib
1212
from pandas._libs.tslib import iNaT
1313
from pandas.util._decorators import Appender, Substitution, deprecate_kwarg
1414

@@ -1440,7 +1440,9 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info):
14401440
}
14411441

14421442

1443-
def _get_take_nd_function(ndim, arr_dtype, out_dtype, axis: int = 0, mask_info=None):
1443+
def _get_take_nd_function(
1444+
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
1445+
):
14441446
if ndim <= 2:
14451447
tup = (arr_dtype.name, out_dtype.name)
14461448
if ndim == 1:
@@ -1474,7 +1476,7 @@ def func2(arr, indexer, out, fill_value=np.nan):
14741476
return func2
14751477

14761478

1477-
def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):
1479+
def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None):
14781480
"""
14791481
Take elements from an array.
14801482
@@ -1568,13 +1570,7 @@ def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):
15681570

15691571

15701572
def take_nd(
1571-
arr,
1572-
indexer,
1573-
axis=0,
1574-
out=None,
1575-
fill_value=np.nan,
1576-
mask_info=None,
1577-
allow_fill: bool = True,
1573+
arr, indexer, axis: int = 0, out=None, fill_value=np.nan, allow_fill: bool = True
15781574
):
15791575
"""
15801576
Specialized Cython take which sets NaN values in one pass
@@ -1597,10 +1593,6 @@ def take_nd(
15971593
maybe_promote to determine this type for any fill_value
15981594
fill_value : any, default np.nan
15991595
Fill value to replace -1 values with
1600-
mask_info : tuple of (ndarray, boolean)
1601-
If provided, value should correspond to:
1602-
(indexer != -1, (indexer != -1).any())
1603-
If not provided, it will be computed internally if necessary
16041596
allow_fill : boolean, default True
16051597
If False, indexer is assumed to contain no -1 values so no filling
16061598
will be done. This short-circuits computation of a mask. Result is
@@ -1611,6 +1603,7 @@ def take_nd(
16111603
subarray : array-like
16121604
May be the same type as the input, or cast to an ndarray.
16131605
"""
1606+
mask_info = None
16141607

16151608
if is_extension_array_dtype(arr):
16161609
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
@@ -1632,12 +1625,9 @@ def take_nd(
16321625
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
16331626
if dtype != arr.dtype and (out is None or out.dtype != dtype):
16341627
# check if promotion is actually required based on indexer
1635-
if mask_info is not None:
1636-
mask, needs_masking = mask_info
1637-
else:
1638-
mask = indexer == -1
1639-
needs_masking = mask.any()
1640-
mask_info = mask, needs_masking
1628+
mask = indexer == -1
1629+
needs_masking = mask.any()
1630+
mask_info = mask, needs_masking
16411631
if needs_masking:
16421632
if out is not None and out.dtype != dtype:
16431633
raise TypeError("Incompatible type for fill_value")
@@ -1818,12 +1808,12 @@ def searchsorted(arr, value, side="left", sorter=None):
18181808
elif not (
18191809
is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr)
18201810
):
1821-
from pandas.core.series import Series
1822-
18231811
# E.g. if `arr` is an array with dtype='datetime64[ns]'
18241812
# and `value` is a pd.Timestamp, we may need to convert value
1825-
value_ser = Series(value)._values
1813+
value_ser = array([value]) if is_scalar(value) else array(value)
18261814
value = value_ser[0] if is_scalar(value) else value_ser
1815+
if isinstance(value, Timestamp) and value.tzinfo is None:
1816+
value = value.to_datetime64()
18271817

18281818
result = arr.searchsorted(value, side=side, sorter=sorter)
18291819
return result

pandas/core/reshape/tile.py

+13 -31
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from pandas._libs import Timedelta, Timestamp
7+
from pandas._libs.interval import Interval
78
from pandas._libs.lib import infer_dtype
89

910
from pandas.core.dtypes.common import (
@@ -18,17 +19,10 @@
1819
is_scalar,
1920
is_timedelta64_dtype,
2021
)
22+
from pandas.core.dtypes.generic import ABCSeries
2123
from pandas.core.dtypes.missing import isna
2224

23-
from pandas import (
24-
Categorical,
25-
Index,
26-
Interval,
27-
IntervalIndex,
28-
Series,
29-
to_datetime,
30-
to_timedelta,
31-
)
25+
from pandas import Categorical, Index, IntervalIndex, to_datetime, to_timedelta
3226
import pandas.core.algorithms as algos
3327
import pandas.core.nanops as nanops
3428

@@ -206,7 +200,8 @@ def cut(
206200
# NOTE: this binning code is changed a bit from histogram for var(x) == 0
207201

208202
# for handling the cut for datetime and timedelta objects
209-
x_is_series, series_index, name, x = _preprocess_for_cut(x)
203+
original = x
204+
x = _preprocess_for_cut(x)
210205
x, dtype = _coerce_to_type(x)
211206

212207
if not np.iterable(bins):
@@ -268,9 +263,7 @@ def cut(
268263
duplicates=duplicates,
269264
)
270265

271-
return _postprocess_for_cut(
272-
fac, bins, retbins, x_is_series, series_index, name, dtype
273-
)
266+
return _postprocess_for_cut(fac, bins, retbins, dtype, original)
274267

275268

276269
def qcut(
@@ -333,8 +326,8 @@ def qcut(
333326
>>> pd.qcut(range(5), 4, labels=False)
334327
array([0, 0, 1, 2, 3])
335328
"""
336-
x_is_series, series_index, name, x = _preprocess_for_cut(x)
337-
329+
original = x
330+
x = _preprocess_for_cut(x)
338331
x, dtype = _coerce_to_type(x)
339332

340333
if is_integer(q):
@@ -352,9 +345,7 @@ def qcut(
352345
duplicates=duplicates,
353346
)
354347

355-
return _postprocess_for_cut(
356-
fac, bins, retbins, x_is_series, series_index, name, dtype
357-
)
348+
return _postprocess_for_cut(fac, bins, retbins, dtype, original)
358349

359350

360351
def _bins_to_cuts(
@@ -544,13 +535,6 @@ def _preprocess_for_cut(x):
544535
input to array, strip the index information and store it
545536
separately
546537
"""
547-
x_is_series = isinstance(x, Series)
548-
series_index = None
549-
name = None
550-
551-
if x_is_series:
552-
series_index = x.index
553-
name = x.name
554538

555539
# Check that the passed array is a Pandas or Numpy object
556540
# We don't want to strip away a Pandas data-type here (e.g. datetimetz)
@@ -560,19 +544,17 @@ def _preprocess_for_cut(x):
560544
if x.ndim != 1:
561545
raise ValueError("Input array must be 1 dimensional")
562546

563-
return x_is_series, series_index, name, x
547+
return x
564548

565549

566-
def _postprocess_for_cut(
567-
fac, bins, retbins: bool, x_is_series, series_index, name, dtype
568-
):
550+
def _postprocess_for_cut(fac, bins, retbins: bool, dtype, original):
569551
"""
570552
handles post processing for the cut method where
571553
we combine the index information if the originally passed
572554
datatype was a series
573555
"""
574-
if x_is_series:
575-
fac = Series(fac, index=series_index, name=name)
556+
if isinstance(original, ABCSeries):
557+
fac = original._constructor(fac, index=original.index, name=original.name)
576558

577559
if not retbins:
578560
return fac

0 commit comments

Comments
 (0)