Skip to content

REF: simplify core.algorithms, reshape.cut #29385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 4, 2019
36 changes: 13 additions & 23 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pandas._libs import algos, hashtable as htable, lib
from pandas._libs import Timestamp, algos, hashtable as htable, lib
from pandas._libs.tslib import iNaT
from pandas.util._decorators import Appender, Substitution, deprecate_kwarg

Expand Down Expand Up @@ -1440,7 +1440,9 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info):
}


def _get_take_nd_function(ndim, arr_dtype, out_dtype, axis: int = 0, mask_info=None):
def _get_take_nd_function(
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
):
if ndim <= 2:
tup = (arr_dtype.name, out_dtype.name)
if ndim == 1:
Expand Down Expand Up @@ -1474,7 +1476,7 @@ def func2(arr, indexer, out, fill_value=np.nan):
return func2


def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):
def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None):
"""
Take elements from an array.

Expand Down Expand Up @@ -1568,13 +1570,7 @@ def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):


def take_nd(
arr,
indexer,
axis=0,
out=None,
fill_value=np.nan,
mask_info=None,
allow_fill: bool = True,
arr, indexer, axis: int = 0, out=None, fill_value=np.nan, allow_fill: bool = True
):
"""
Specialized Cython take which sets NaN values in one pass
Expand All @@ -1597,10 +1593,6 @@ def take_nd(
maybe_promote to determine this type for any fill_value
fill_value : any, default np.nan
Fill value to replace -1 values with
mask_info : tuple of (ndarray, boolean)
If provided, value should correspond to:
(indexer != -1, (indexer != -1).any())
If not provided, it will be computed internally if necessary
allow_fill : boolean, default True
If False, indexer is assumed to contain no -1 values so no filling
will be done. This short-circuits computation of a mask. Result is
Expand All @@ -1611,6 +1603,7 @@ def take_nd(
subarray : array-like
May be the same type as the input, or cast to an ndarray.
"""
mask_info = None

if is_extension_array_dtype(arr):
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
Expand All @@ -1632,12 +1625,9 @@ def take_nd(
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
if dtype != arr.dtype and (out is None or out.dtype != dtype):
# check if promotion is actually required based on indexer
if mask_info is not None:
mask, needs_masking = mask_info
else:
mask = indexer == -1
needs_masking = mask.any()
mask_info = mask, needs_masking
mask = indexer == -1
needs_masking = mask.any()
mask_info = mask, needs_masking
if needs_masking:
if out is not None and out.dtype != dtype:
raise TypeError("Incompatible type for fill_value")
Expand Down Expand Up @@ -1818,12 +1808,12 @@ def searchsorted(arr, value, side="left", sorter=None):
elif not (
is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr)
):
from pandas.core.series import Series

# E.g. if `arr` is an array with dtype='datetime64[ns]'
# and `value` is a pd.Timestamp, we may need to convert value
value_ser = Series(value)._values
value_ser = array([value]) if is_scalar(value) else array(value)
value = value_ser[0] if is_scalar(value) else value_ser
if isinstance(value, Timestamp) and value.tzinfo is None:
value = value.to_datetime64()

result = arr.searchsorted(value, side=side, sorter=sorter)
return result
Expand Down
44 changes: 13 additions & 31 deletions pandas/core/reshape/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from pandas._libs import Timedelta, Timestamp
from pandas._libs.interval import Interval
from pandas._libs.lib import infer_dtype

from pandas.core.dtypes.common import (
Expand All @@ -18,17 +19,10 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import isna

from pandas import (
Categorical,
Index,
Interval,
IntervalIndex,
Series,
to_datetime,
to_timedelta,
)
from pandas import Categorical, Index, IntervalIndex, to_datetime, to_timedelta
import pandas.core.algorithms as algos
import pandas.core.nanops as nanops

Expand Down Expand Up @@ -206,7 +200,8 @@ def cut(
# NOTE: this binning code is changed a bit from histogram for var(x) == 0

# for handling the cut for datetime and timedelta objects
x_is_series, series_index, name, x = _preprocess_for_cut(x)
original = x
x = _preprocess_for_cut(x)
x, dtype = _coerce_to_type(x)

if not np.iterable(bins):
Expand Down Expand Up @@ -268,9 +263,7 @@ def cut(
duplicates=duplicates,
)

return _postprocess_for_cut(
fac, bins, retbins, x_is_series, series_index, name, dtype
)
return _postprocess_for_cut(fac, bins, retbins, dtype, original)


def qcut(
Expand Down Expand Up @@ -333,8 +326,8 @@ def qcut(
>>> pd.qcut(range(5), 4, labels=False)
array([0, 0, 1, 2, 3])
"""
x_is_series, series_index, name, x = _preprocess_for_cut(x)

original = x
x = _preprocess_for_cut(x)
x, dtype = _coerce_to_type(x)

if is_integer(q):
Expand All @@ -352,9 +345,7 @@ def qcut(
duplicates=duplicates,
)

return _postprocess_for_cut(
fac, bins, retbins, x_is_series, series_index, name, dtype
)
return _postprocess_for_cut(fac, bins, retbins, dtype, original)


def _bins_to_cuts(
Expand Down Expand Up @@ -544,13 +535,6 @@ def _preprocess_for_cut(x):
input to array, strip the index information and store it
separately
"""
x_is_series = isinstance(x, Series)
series_index = None
name = None

if x_is_series:
series_index = x.index
name = x.name

# Check that the passed array is a Pandas or Numpy object
# We don't want to strip away a Pandas data-type here (e.g. datetimetz)
Expand All @@ -560,19 +544,17 @@ def _preprocess_for_cut(x):
if x.ndim != 1:
raise ValueError("Input array must be 1 dimensional")

return x_is_series, series_index, name, x
return x


def _postprocess_for_cut(
fac, bins, retbins: bool, x_is_series, series_index, name, dtype
):
def _postprocess_for_cut(fac, bins, retbins: bool, dtype, original):
"""
handles post processing for the cut method where
we combine the index information if the originally passed
datatype was a series
"""
if x_is_series:
fac = Series(fac, index=series_index, name=name)
if isinstance(original, ABCSeries):
fac = original._constructor(fac, index=original.index, name=original.name)

if not retbins:
return fac
Expand Down