Skip to content

Commit 56a5f5c

Browse files
authored
CLN: share rank_t fused type (#43789)
1 parent 37f0363 commit 56a5f5c

File tree

3 files changed

+78
-69
lines changed

3 files changed

+78
-69
lines changed

pandas/_libs/algos.pyx

+36-37
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ from numpy cimport (
4545
cnp.import_array()
4646

4747
cimport pandas._libs.util as util
48+
from pandas._libs.dtypes cimport numeric_object_t
4849
from pandas._libs.khash cimport (
4950
kh_destroy_int64,
5051
kh_get_int64,
@@ -860,34 +861,30 @@ def is_monotonic(ndarray[algos_t, ndim=1] arr, bint timelike):
860861
# rank_1d, rank_2d
861862
# ----------------------------------------------------------------------
862863

863-
ctypedef fused rank_t:
864-
object
865-
float64_t
866-
uint64_t
867-
int64_t
868-
869-
870-
cdef rank_t get_rank_nan_fill_val(bint rank_nans_highest, rank_t[:] _=None):
864+
cdef numeric_object_t get_rank_nan_fill_val(
865+
bint rank_nans_highest,
866+
numeric_object_t[:] _=None
867+
):
871868
"""
872869
Return the value we'll use to represent missing values when sorting depending
873870
on if we'd like missing values to end up at the top/bottom. (The second parameter
874871
is unused, but needed for fused type specialization)
875872
"""
876873
if rank_nans_highest:
877-
if rank_t is object:
874+
if numeric_object_t is object:
878875
return Infinity()
879-
elif rank_t is int64_t:
876+
elif numeric_object_t is int64_t:
880877
return util.INT64_MAX
881-
elif rank_t is uint64_t:
878+
elif numeric_object_t is uint64_t:
882879
return util.UINT64_MAX
883880
else:
884881
return np.inf
885882
else:
886-
if rank_t is object:
883+
if numeric_object_t is object:
887884
return NegInfinity()
888-
elif rank_t is int64_t:
885+
elif numeric_object_t is int64_t:
889886
return NPY_NAT
890-
elif rank_t is uint64_t:
887+
elif numeric_object_t is uint64_t:
891888
return 0
892889
else:
893890
return -np.inf
@@ -896,7 +893,7 @@ cdef rank_t get_rank_nan_fill_val(bint rank_nans_highest, rank_t[:] _=None):
896893
@cython.wraparound(False)
897894
@cython.boundscheck(False)
898895
def rank_1d(
899-
ndarray[rank_t, ndim=1] values,
896+
ndarray[numeric_object_t, ndim=1] values,
900897
const intp_t[:] labels=None,
901898
bint is_datetimelike=False,
902899
ties_method="average",
@@ -909,7 +906,7 @@ def rank_1d(
909906
910907
Parameters
911908
----------
912-
values : array of rank_t values to be ranked
909+
values : array of numeric_object_t values to be ranked
913910
labels : np.ndarray[np.intp] or None
914911
Array containing unique label for each group, with its ordering
915912
matching up to the corresponding record in `values`. If not called
@@ -939,11 +936,11 @@ def rank_1d(
939936
int64_t[::1] grp_sizes
940937
intp_t[:] lexsort_indexer
941938
float64_t[::1] out
942-
ndarray[rank_t, ndim=1] masked_vals
943-
rank_t[:] masked_vals_memview
939+
ndarray[numeric_object_t, ndim=1] masked_vals
940+
numeric_object_t[:] masked_vals_memview
944941
uint8_t[:] mask
945942
bint keep_na, nans_rank_highest, check_labels, check_mask
946-
rank_t nan_fill_val
943+
numeric_object_t nan_fill_val
947944

948945
tiebreak = tiebreakers[ties_method]
949946
if tiebreak == TIEBREAK_FIRST:
@@ -964,21 +961,22 @@ def rank_1d(
964961
check_labels = labels is not None
965962

966963
# For cases where a mask is not possible, we can avoid mask checks
967-
check_mask = not (rank_t is uint64_t or (rank_t is int64_t and not is_datetimelike))
964+
check_mask = not (numeric_object_t is uint64_t or
965+
(numeric_object_t is int64_t and not is_datetimelike))
968966

969967
# Copy values into new array in order to fill missing data
970968
# with mask, without obfuscating location of missing data
971969
# in values array
972-
if rank_t is object and values.dtype != np.object_:
970+
if numeric_object_t is object and values.dtype != np.object_:
973971
masked_vals = values.astype('O')
974972
else:
975973
masked_vals = values.copy()
976974

977-
if rank_t is object:
975+
if numeric_object_t is object:
978976
mask = missing.isnaobj(masked_vals)
979-
elif rank_t is int64_t and is_datetimelike:
977+
elif numeric_object_t is int64_t and is_datetimelike:
980978
mask = (masked_vals == NPY_NAT).astype(np.uint8)
981-
elif rank_t is float64_t:
979+
elif numeric_object_t is float64_t:
982980
mask = np.isnan(masked_vals).astype(np.uint8)
983981
else:
984982
mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)
@@ -990,7 +988,7 @@ def rank_1d(
990988
# will flip the ordering to still end up with lowest rank.
991989
# Symmetric logic applies to `na_option == 'bottom'`
992990
nans_rank_highest = ascending ^ (na_option == 'top')
993-
nan_fill_val = get_rank_nan_fill_val[rank_t](nans_rank_highest)
991+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
994992
if nans_rank_highest:
995993
order = [masked_vals, mask]
996994
else:
@@ -1037,7 +1035,7 @@ cdef void rank_sorted_1d(
10371035
int64_t[::1] grp_sizes,
10381036
const intp_t[:] sort_indexer,
10391037
# Can make const with cython3 (https://github.com/cython/cython/issues/3222)
1040-
rank_t[:] masked_vals,
1038+
numeric_object_t[:] masked_vals,
10411039
const uint8_t[:] mask,
10421040
bint check_mask,
10431041
Py_ssize_t N,
@@ -1061,7 +1059,7 @@ cdef void rank_sorted_1d(
10611059
if labels is None.
10621060
sort_indexer : intp_t[:]
10631061
Array of indices which sorts masked_vals
1064-
masked_vals : rank_t[:]
1062+
masked_vals : numeric_object_t[:]
10651063
The values input to rank_1d, with missing values replaced by fill values
10661064
mask : uint8_t[:]
10671065
Array where entries are True if the value is missing, False otherwise.
@@ -1093,7 +1091,7 @@ cdef void rank_sorted_1d(
10931091
# that sorted value for retrieval back from the original
10941092
# values / masked_vals arrays
10951093
# TODO: de-duplicate once cython supports conditional nogil
1096-
if rank_t is object:
1094+
if numeric_object_t is object:
10971095
with gil:
10981096
for i in range(N):
10991097
at_end = i == N - 1
@@ -1301,7 +1299,7 @@ cdef void rank_sorted_1d(
13011299

13021300

13031301
def rank_2d(
1304-
ndarray[rank_t, ndim=2] in_arr,
1302+
ndarray[numeric_object_t, ndim=2] in_arr,
13051303
int axis=0,
13061304
bint is_datetimelike=False,
13071305
ties_method="average",
@@ -1316,13 +1314,13 @@ def rank_2d(
13161314
Py_ssize_t k, n, col
13171315
float64_t[::1, :] out # Column-major so columns are contiguous
13181316
int64_t[::1] grp_sizes
1319-
ndarray[rank_t, ndim=2] values
1320-
rank_t[:, :] masked_vals
1317+
ndarray[numeric_object_t, ndim=2] values
1318+
numeric_object_t[:, :] masked_vals
13211319
intp_t[:, :] sort_indexer
13221320
uint8_t[:, :] mask
13231321
TiebreakEnumType tiebreak
13241322
bint check_mask, keep_na, nans_rank_highest
1325-
rank_t nan_fill_val
1323+
numeric_object_t nan_fill_val
13261324

13271325
tiebreak = tiebreakers[ties_method]
13281326
if tiebreak == TIEBREAK_FIRST:
@@ -1332,24 +1330,25 @@ def rank_2d(
13321330
keep_na = na_option == 'keep'
13331331

13341332
# For cases where a mask is not possible, we can avoid mask checks
1335-
check_mask = not (rank_t is uint64_t or (rank_t is int64_t and not is_datetimelike))
1333+
check_mask = not (numeric_object_t is uint64_t or
1334+
(numeric_object_t is int64_t and not is_datetimelike))
13361335

13371336
if axis == 1:
13381337
values = np.asarray(in_arr).T.copy()
13391338
else:
13401339
values = np.asarray(in_arr).copy()
13411340

1342-
if rank_t is object:
1341+
if numeric_object_t is object:
13431342
if values.dtype != np.object_:
13441343
values = values.astype('O')
13451344

13461345
nans_rank_highest = ascending ^ (na_option == 'top')
13471346
if check_mask:
1348-
nan_fill_val = get_rank_nan_fill_val[rank_t](nans_rank_highest)
1347+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
13491348

1350-
if rank_t is object:
1349+
if numeric_object_t is object:
13511350
mask = missing.isnaobj2d(values).view(np.uint8)
1352-
elif rank_t is float64_t:
1351+
elif numeric_object_t is float64_t:
13531352
mask = np.isnan(values).view(np.uint8)
13541353

13551354
# int64 and datetimelike

pandas/_libs/dtypes.pxd

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Common location for shared fused types
3+
"""
4+
5+
from numpy cimport (
6+
float32_t,
7+
float64_t,
8+
int64_t,
9+
uint64_t,
10+
)
11+
12+
ctypedef fused numeric_object_t:
13+
float64_t
14+
float32_t
15+
int64_t
16+
uint64_t
17+
object

pandas/_libs/groupby.pyx

+25-32
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ from pandas._libs.algos import (
4343
take_2d_axis1_float64_float64,
4444
)
4545

46+
from pandas._libs.dtypes cimport numeric_object_t
4647
from pandas._libs.missing cimport checknull
4748

4849

@@ -921,45 +922,37 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
921922
# group_nth, group_last, group_rank
922923
# ----------------------------------------------------------------------
923924

924-
ctypedef fused rank_t:
925-
float64_t
926-
float32_t
927-
int64_t
928-
uint64_t
929-
object
930-
931-
932-
cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
933-
if rank_t is object:
925+
cdef inline bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil:
926+
if numeric_object_t is object:
934927
# Should never be used, but we need to avoid the `val != val` below
935928
# or else cython will raise about gil acquisition.
936929
raise NotImplementedError
937930

938-
elif rank_t is int64_t:
931+
elif numeric_object_t is int64_t:
939932
return is_datetimelike and val == NPY_NAT
940-
elif rank_t is uint64_t:
933+
elif numeric_object_t is uint64_t:
941934
# There is no NA value for uint64
942935
return False
943936
else:
944937
return val != val
945938

946939

947940
# GH#31710 use memorviews once cython 0.30 is released so we can
948-
# use `const rank_t[:, :] values`
941+
# use `const numeric_object_t[:, :] values`
949942
@cython.wraparound(False)
950943
@cython.boundscheck(False)
951-
def group_last(rank_t[:, ::1] out,
944+
def group_last(numeric_object_t[:, ::1] out,
952945
int64_t[::1] counts,
953-
ndarray[rank_t, ndim=2] values,
946+
ndarray[numeric_object_t, ndim=2] values,
954947
const intp_t[::1] labels,
955948
Py_ssize_t min_count=-1) -> None:
956949
"""
957950
Only aggregates on axis=0
958951
"""
959952
cdef:
960953
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
961-
rank_t val
962-
ndarray[rank_t, ndim=2] resx
954+
numeric_object_t val
955+
ndarray[numeric_object_t, ndim=2] resx
963956
ndarray[int64_t, ndim=2] nobs
964957
bint runtime_error = False
965958

@@ -970,14 +963,14 @@ def group_last(rank_t[:, ::1] out,
970963

971964
min_count = max(min_count, 1)
972965
nobs = np.zeros((<object>out).shape, dtype=np.int64)
973-
if rank_t is object:
966+
if numeric_object_t is object:
974967
resx = np.empty((<object>out).shape, dtype=object)
975968
else:
976969
resx = np.empty_like(out)
977970

978971
N, K = (<object>values).shape
979972

980-
if rank_t is object:
973+
if numeric_object_t is object:
981974
# TODO: De-duplicate once conditional-nogil is available
982975
for i in range(N):
983976
lab = labels[i]
@@ -1019,9 +1012,9 @@ def group_last(rank_t[:, ::1] out,
10191012
for i in range(ncounts):
10201013
for j in range(K):
10211014
if nobs[i, j] < min_count:
1022-
if rank_t is int64_t:
1015+
if numeric_object_t is int64_t:
10231016
out[i, j] = NPY_NAT
1024-
elif rank_t is uint64_t:
1017+
elif numeric_object_t is uint64_t:
10251018
runtime_error = True
10261019
break
10271020
else:
@@ -1037,12 +1030,12 @@ def group_last(rank_t[:, ::1] out,
10371030

10381031

10391032
# GH#31710 use memorviews once cython 0.30 is released so we can
1040-
# use `const rank_t[:, :] values`
1033+
# use `const numeric_object_t[:, :] values`
10411034
@cython.wraparound(False)
10421035
@cython.boundscheck(False)
1043-
def group_nth(rank_t[:, ::1] out,
1036+
def group_nth(numeric_object_t[:, ::1] out,
10441037
int64_t[::1] counts,
1045-
ndarray[rank_t, ndim=2] values,
1038+
ndarray[numeric_object_t, ndim=2] values,
10461039
const intp_t[::1] labels,
10471040
int64_t min_count=-1,
10481041
int64_t rank=1,
@@ -1052,8 +1045,8 @@ def group_nth(rank_t[:, ::1] out,
10521045
"""
10531046
cdef:
10541047
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
1055-
rank_t val
1056-
ndarray[rank_t, ndim=2] resx
1048+
numeric_object_t val
1049+
ndarray[numeric_object_t, ndim=2] resx
10571050
ndarray[int64_t, ndim=2] nobs
10581051
bint runtime_error = False
10591052

@@ -1064,14 +1057,14 @@ def group_nth(rank_t[:, ::1] out,
10641057

10651058
min_count = max(min_count, 1)
10661059
nobs = np.zeros((<object>out).shape, dtype=np.int64)
1067-
if rank_t is object:
1060+
if numeric_object_t is object:
10681061
resx = np.empty((<object>out).shape, dtype=object)
10691062
else:
10701063
resx = np.empty_like(out)
10711064

10721065
N, K = (<object>values).shape
10731066

1074-
if rank_t is object:
1067+
if numeric_object_t is object:
10751068
# TODO: De-duplicate once conditional-nogil is available
10761069
for i in range(N):
10771070
lab = labels[i]
@@ -1116,9 +1109,9 @@ def group_nth(rank_t[:, ::1] out,
11161109
for i in range(ncounts):
11171110
for j in range(K):
11181111
if nobs[i, j] < min_count:
1119-
if rank_t is int64_t:
1112+
if numeric_object_t is int64_t:
11201113
out[i, j] = NPY_NAT
1121-
elif rank_t is uint64_t:
1114+
elif numeric_object_t is uint64_t:
11221115
runtime_error = True
11231116
break
11241117
else:
@@ -1135,7 +1128,7 @@ def group_nth(rank_t[:, ::1] out,
11351128
@cython.boundscheck(False)
11361129
@cython.wraparound(False)
11371130
def group_rank(float64_t[:, ::1] out,
1138-
ndarray[rank_t, ndim=2] values,
1131+
ndarray[numeric_object_t, ndim=2] values,
11391132
const intp_t[::1] labels,
11401133
int ngroups,
11411134
bint is_datetimelike, str ties_method="average",
@@ -1147,7 +1140,7 @@ def group_rank(float64_t[:, ::1] out,
11471140
----------
11481141
out : np.ndarray[np.float64, ndim=2]
11491142
Values to which this method will write its results.
1150-
values : np.ndarray of rank_t values to be ranked
1143+
values : np.ndarray of numeric_object_t values to be ranked
11511144
labels : np.ndarray[np.intp]
11521145
Array containing unique label for each group, with its ordering
11531146
matching up to the corresponding record in `values`

0 commit comments

Comments
 (0)