Skip to content

Commit d801a4b

Browse files
authored
PERF: avoid cast in algos.rank (#46175)
1 parent e70d310 commit d801a4b

File tree

3 files changed

+70
-45
lines changed

3 files changed

+70
-45
lines changed

pandas/_libs/algos.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def is_monotonic(
102102
# ----------------------------------------------------------------------
103103

104104
def rank_1d(
105-
values: np.ndarray, # ndarray[iu_64_floating_obj_t, ndim=1]
105+
values: np.ndarray, # ndarray[numeric_object_t, ndim=1]
106106
labels: np.ndarray | None = ..., # const intp_t[:]=None
107107
is_datetimelike: bool = ...,
108108
ties_method=...,
@@ -111,7 +111,7 @@ def rank_1d(
111111
na_option=...,
112112
) -> np.ndarray: ... # np.ndarray[float64_t, ndim=1]
113113
def rank_2d(
114-
in_arr: np.ndarray, # ndarray[iu_64_floating_obj_t, ndim=2]
114+
in_arr: np.ndarray, # ndarray[numeric_object_t, ndim=2]
115115
axis: int = ...,
116116
is_datetimelike: bool = ...,
117117
ties_method=...,

pandas/_libs/algos.pyx

+68-38
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ cnp.import_array()
4646

4747
cimport pandas._libs.util as util
4848
from pandas._libs.dtypes cimport (
49-
iu_64_floating_obj_t,
5049
numeric_object_t,
5150
numeric_t,
5251
)
@@ -821,30 +820,54 @@ def is_monotonic(ndarray[numeric_object_t, ndim=1] arr, bint timelike):
821820
# rank_1d, rank_2d
822821
# ----------------------------------------------------------------------
823822

824-
cdef iu_64_floating_obj_t get_rank_nan_fill_val(
825-
bint rank_nans_highest,
826-
iu_64_floating_obj_t[:] _=None
823+
cdef numeric_object_t get_rank_nan_fill_val(
824+
bint rank_nans_highest,
825+
numeric_object_t[:] _=None
827826
):
828827
"""
829828
Return the value we'll use to represent missing values when sorting depending
830829
on if we'd like missing values to end up at the top/bottom. (The second parameter
831830
is unused, but needed for fused type specialization)
832831
"""
833832
if rank_nans_highest:
834-
if iu_64_floating_obj_t is object:
833+
if numeric_object_t is object:
835834
return Infinity()
836-
elif iu_64_floating_obj_t is int64_t:
835+
elif numeric_object_t is int64_t:
837836
return util.INT64_MAX
838-
elif iu_64_floating_obj_t is uint64_t:
837+
elif numeric_object_t is int32_t:
838+
return util.INT32_MAX
839+
elif numeric_object_t is int16_t:
840+
return util.INT16_MAX
841+
elif numeric_object_t is int8_t:
842+
return util.INT8_MAX
843+
elif numeric_object_t is uint64_t:
839844
return util.UINT64_MAX
845+
elif numeric_object_t is uint32_t:
846+
return util.UINT32_MAX
847+
elif numeric_object_t is uint16_t:
848+
return util.UINT16_MAX
849+
elif numeric_object_t is uint8_t:
850+
return util.UINT8_MAX
840851
else:
841852
return np.inf
842853
else:
843-
if iu_64_floating_obj_t is object:
854+
if numeric_object_t is object:
844855
return NegInfinity()
845-
elif iu_64_floating_obj_t is int64_t:
856+
elif numeric_object_t is int64_t:
846857
return NPY_NAT
847-
elif iu_64_floating_obj_t is uint64_t:
858+
elif numeric_object_t is int32_t:
859+
return util.INT32_MIN
860+
elif numeric_object_t is int16_t:
861+
return util.INT16_MIN
862+
elif numeric_object_t is int8_t:
863+
return util.INT8_MIN
864+
elif numeric_object_t is uint64_t:
865+
return 0
866+
elif numeric_object_t is uint32_t:
867+
return 0
868+
elif numeric_object_t is uint16_t:
869+
return 0
870+
elif numeric_object_t is uint8_t:
848871
return 0
849872
else:
850873
return -np.inf
@@ -853,7 +876,7 @@ cdef iu_64_floating_obj_t get_rank_nan_fill_val(
853876
@cython.wraparound(False)
854877
@cython.boundscheck(False)
855878
def rank_1d(
856-
ndarray[iu_64_floating_obj_t, ndim=1] values,
879+
ndarray[numeric_object_t, ndim=1] values,
857880
const intp_t[:] labels=None,
858881
bint is_datetimelike=False,
859882
ties_method="average",
@@ -866,7 +889,7 @@ def rank_1d(
866889
867890
Parameters
868891
----------
869-
values : array of iu_64_floating_obj_t values to be ranked
892+
values : array of numeric_object_t values to be ranked
870893
labels : np.ndarray[np.intp] or None
871894
Array containing unique label for each group, with its ordering
872895
matching up to the corresponding record in `values`. If not called
@@ -896,11 +919,11 @@ def rank_1d(
896919
int64_t[::1] grp_sizes
897920
intp_t[:] lexsort_indexer
898921
float64_t[::1] out
899-
ndarray[iu_64_floating_obj_t, ndim=1] masked_vals
900-
iu_64_floating_obj_t[:] masked_vals_memview
922+
ndarray[numeric_object_t, ndim=1] masked_vals
923+
numeric_object_t[:] masked_vals_memview
901924
uint8_t[:] mask
902925
bint keep_na, nans_rank_highest, check_labels, check_mask
903-
iu_64_floating_obj_t nan_fill_val
926+
numeric_object_t nan_fill_val
904927

905928
tiebreak = tiebreakers[ties_method]
906929
if tiebreak == TIEBREAK_FIRST:
@@ -921,22 +944,26 @@ def rank_1d(
921944
check_labels = labels is not None
922945

923946
# For cases where a mask is not possible, we can avoid mask checks
924-
check_mask = not (iu_64_floating_obj_t is uint64_t or
925-
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
947+
check_mask = (
948+
numeric_object_t is float32_t
949+
or numeric_object_t is float64_t
950+
or numeric_object_t is object
951+
or (numeric_object_t is int64_t and is_datetimelike)
952+
)
926953

927954
# Copy values into new array in order to fill missing data
928955
# with mask, without obfuscating location of missing data
929956
# in values array
930-
if iu_64_floating_obj_t is object and values.dtype != np.object_:
957+
if numeric_object_t is object and values.dtype != np.object_:
931958
masked_vals = values.astype('O')
932959
else:
933960
masked_vals = values.copy()
934961

935-
if iu_64_floating_obj_t is object:
962+
if numeric_object_t is object:
936963
mask = missing.isnaobj(masked_vals)
937-
elif iu_64_floating_obj_t is int64_t and is_datetimelike:
964+
elif numeric_object_t is int64_t and is_datetimelike:
938965
mask = (masked_vals == NPY_NAT).astype(np.uint8)
939-
elif iu_64_floating_obj_t is float64_t:
966+
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
940967
mask = np.isnan(masked_vals).astype(np.uint8)
941968
else:
942969
mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)
@@ -948,7 +975,7 @@ def rank_1d(
948975
# will flip the ordering to still end up with lowest rank.
949976
# Symmetric logic applies to `na_option == 'bottom'`
950977
nans_rank_highest = ascending ^ (na_option == 'top')
951-
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
978+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
952979
if nans_rank_highest:
953980
order = [masked_vals, mask]
954981
else:
@@ -994,8 +1021,8 @@ cdef void rank_sorted_1d(
9941021
float64_t[::1] out,
9951022
int64_t[::1] grp_sizes,
9961023
const intp_t[:] sort_indexer,
997-
# Can make const with cython3 (https://github.com/cython/cython/issues/3222)
998-
iu_64_floating_obj_t[:] masked_vals,
1024+
# TODO(cython3): make const (https://github.com/cython/cython/issues/3222)
1025+
numeric_object_t[:] masked_vals,
9991026
const uint8_t[:] mask,
10001027
bint check_mask,
10011028
Py_ssize_t N,
@@ -1019,7 +1046,7 @@ cdef void rank_sorted_1d(
10191046
if labels is None.
10201047
sort_indexer : intp_t[:]
10211048
Array of indices which sorts masked_vals
1022-
masked_vals : iu_64_floating_obj_t[:]
1049+
masked_vals : numeric_object_t[:]
10231050
The values input to rank_1d, with missing values replaced by fill values
10241051
mask : uint8_t[:]
10251052
Array where entries are True if the value is missing, False otherwise.
@@ -1051,7 +1078,7 @@ cdef void rank_sorted_1d(
10511078
# that sorted value for retrieval back from the original
10521079
# values / masked_vals arrays
10531080
# TODO(cython3): de-duplicate once cython supports conditional nogil
1054-
if iu_64_floating_obj_t is object:
1081+
if numeric_object_t is object:
10551082
with gil:
10561083
for i in range(N):
10571084
at_end = i == N - 1
@@ -1259,7 +1286,7 @@ cdef void rank_sorted_1d(
12591286

12601287

12611288
def rank_2d(
1262-
ndarray[iu_64_floating_obj_t, ndim=2] in_arr,
1289+
ndarray[numeric_object_t, ndim=2] in_arr,
12631290
int axis=0,
12641291
bint is_datetimelike=False,
12651292
ties_method="average",
@@ -1274,13 +1301,13 @@ def rank_2d(
12741301
Py_ssize_t k, n, col
12751302
float64_t[::1, :] out # Column-major so columns are contiguous
12761303
int64_t[::1] grp_sizes
1277-
ndarray[iu_64_floating_obj_t, ndim=2] values
1278-
iu_64_floating_obj_t[:, :] masked_vals
1304+
ndarray[numeric_object_t, ndim=2] values
1305+
numeric_object_t[:, :] masked_vals
12791306
intp_t[:, :] sort_indexer
12801307
uint8_t[:, :] mask
12811308
TiebreakEnumType tiebreak
12821309
bint check_mask, keep_na, nans_rank_highest
1283-
iu_64_floating_obj_t nan_fill_val
1310+
numeric_object_t nan_fill_val
12841311

12851312
tiebreak = tiebreakers[ties_method]
12861313
if tiebreak == TIEBREAK_FIRST:
@@ -1290,29 +1317,32 @@ def rank_2d(
12901317
keep_na = na_option == 'keep'
12911318

12921319
# For cases where a mask is not possible, we can avoid mask checks
1293-
check_mask = not (iu_64_floating_obj_t is uint64_t or
1294-
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
1320+
check_mask = (
1321+
numeric_object_t is float32_t
1322+
or numeric_object_t is float64_t
1323+
or numeric_object_t is object
1324+
or (numeric_object_t is int64_t and is_datetimelike)
1325+
)
12951326

12961327
if axis == 1:
12971328
values = np.asarray(in_arr).T.copy()
12981329
else:
12991330
values = np.asarray(in_arr).copy()
13001331

1301-
if iu_64_floating_obj_t is object:
1332+
if numeric_object_t is object:
13021333
if values.dtype != np.object_:
13031334
values = values.astype('O')
13041335

13051336
nans_rank_highest = ascending ^ (na_option == 'top')
13061337
if check_mask:
1307-
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
1338+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
13081339

1309-
if iu_64_floating_obj_t is object:
1340+
if numeric_object_t is object:
13101341
mask = missing.isnaobj2d(values).view(np.uint8)
1311-
elif iu_64_floating_obj_t is float64_t:
1342+
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
13121343
mask = np.isnan(values).view(np.uint8)
1313-
1314-
# int64 and datetimelike
13151344
else:
1345+
# i.e. int64 and datetimelike
13161346
mask = (values == NPY_NAT).view(np.uint8)
13171347
np.putmask(values, mask, nan_fill_val)
13181348
else:

pandas/core/algorithms.py

-5
Original file line numberDiff line numberDiff line change
@@ -1003,11 +1003,6 @@ def rank(
10031003
is_datetimelike = needs_i8_conversion(values.dtype)
10041004
values = _ensure_data(values)
10051005

1006-
if values.dtype.kind in ["i", "u", "f"]:
1007-
# rank_t includes only object, int64, uint64, float64
1008-
dtype = values.dtype.kind + "8"
1009-
values = values.astype(dtype, copy=False)
1010-
10111006
if values.ndim == 1:
10121007
ranks = algos.rank_1d(
10131008
values,

0 commit comments

Comments
 (0)