Skip to content

Commit bc62e76

Browse files
authored
CLN: rank_1d (pandas-dev#40546)
1 parent f75a620 commit bc62e76

File tree

1 file changed

+83
-66
lines changed

1 file changed

+83
-66
lines changed

pandas/_libs/algos.pyx

+83-66
Original file line numberDiff line numberDiff line change
@@ -995,22 +995,27 @@ def rank_1d(
995995
cdef:
996996
TiebreakEnumType tiebreak
997997
Py_ssize_t i, j, N, grp_start=0, dups=0, sum_ranks=0
998-
Py_ssize_t grp_vals_seen=1, grp_na_count=0, grp_tie_count=0
998+
Py_ssize_t grp_vals_seen=1, grp_na_count=0
999999
ndarray[int64_t, ndim=1] lexsort_indexer
10001000
ndarray[float64_t, ndim=1] grp_sizes, out
10011001
ndarray[rank_t, ndim=1] masked_vals
10021002
ndarray[uint8_t, ndim=1] mask
1003-
bint keep_na, at_end, next_val_diff, check_labels
1003+
bint keep_na, at_end, next_val_diff, check_labels, group_changed
10041004
rank_t nan_fill_val
10051005

10061006
tiebreak = tiebreakers[ties_method]
1007+
if tiebreak == TIEBREAK_FIRST:
1008+
if not ascending:
1009+
tiebreak = TIEBREAK_FIRST_DESCENDING
1010+
10071011
keep_na = na_option == 'keep'
10081012

10091013
N = len(values)
10101014
# TODO Cython 3.0: cast won't be necessary (#2992)
10111015
assert <Py_ssize_t>len(labels) == N
10121016
out = np.empty(N)
10131017
grp_sizes = np.ones(N)
1018+
10141019
# If all 0 labels, can short-circuit later label
10151020
# comparisons
10161021
check_labels = np.any(labels)
@@ -1032,6 +1037,12 @@ def rank_1d(
10321037
else:
10331038
mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)
10341039

1040+
# If `na_option == 'top'`, we want to assign the lowest rank
1041+
# to NaN regardless of ascending/descending. So if ascending,
1042+
# fill with lowest value of type to end up with lowest rank.
1043+
# If descending, fill with highest value since descending
1044+
# will flip the ordering to still end up with lowest rank.
1045+
# Symmetric logic applies to `na_option == 'bottom'`
10351046
if ascending ^ (na_option == 'top'):
10361047
if rank_t is object:
10371048
nan_fill_val = Infinity()
@@ -1074,36 +1085,36 @@ def rank_1d(
10741085
if rank_t is object:
10751086
for i in range(N):
10761087
at_end = i == N - 1
1088+
10771089
# dups and sum_ranks will be incremented each loop where
10781090
# the value / group remains the same, and should be reset
1079-
# when either of those change
1080-
# Used to calculate tiebreakers
1091+
# when either of those change. Used to calculate tiebreakers
10811092
dups += 1
10821093
sum_ranks += i - grp_start + 1
10831094

1095+
next_val_diff = at_end or are_diff(masked_vals[lexsort_indexer[i]],
1096+
masked_vals[lexsort_indexer[i+1]])
1097+
1098+
# We'll need this check later anyway to determine group size, so just
1099+
# compute it here since short-circuiting won't help
1100+
group_changed = at_end or (check_labels and
1101+
(labels[lexsort_indexer[i]]
1102+
!= labels[lexsort_indexer[i+1]]))
1103+
10841104
# Update out only when there is a transition of values or labels.
10851105
# When a new value or group is encountered, go back #dups steps (
10861106
# the number of occurrences of the current value) and assign the ranks
10871107
# based on the starting index of the current group (grp_start)
10881108
# and the current index
1089-
if not at_end:
1090-
next_val_diff = are_diff(masked_vals[lexsort_indexer[i]],
1091-
masked_vals[lexsort_indexer[i+1]])
1092-
else:
1093-
next_val_diff = True
1094-
1095-
if (next_val_diff
1096-
or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+1]])
1097-
or (check_labels
1098-
and (labels[lexsort_indexer[i]]
1099-
!= labels[lexsort_indexer[i+1]]))
1100-
):
1101-
# if keep_na, check for missing values and assign back
1109+
if (next_val_diff or group_changed
1110+
or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+1]])):
1111+
1112+
# If keep_na, check for missing values and assign back
11021113
# to the result where appropriate
11031114
if keep_na and mask[lexsort_indexer[i]]:
1115+
grp_na_count = dups
11041116
for j in range(i - dups + 1, i + 1):
11051117
out[lexsort_indexer[j]] = NaN
1106-
grp_na_count = dups
11071118
elif tiebreak == TIEBREAK_AVERAGE:
11081119
for j in range(i - dups + 1, i + 1):
11091120
out[lexsort_indexer[j]] = sum_ranks / <float64_t>dups
@@ -1113,84 +1124,87 @@ def rank_1d(
11131124
elif tiebreak == TIEBREAK_MAX:
11141125
for j in range(i - dups + 1, i + 1):
11151126
out[lexsort_indexer[j]] = i - grp_start + 1
1127+
1128+
# With n as the previous rank in the group and m as the number
1129+
# of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1130+
# then rankings should be n + 1, n + 2 ... n + m
11161131
elif tiebreak == TIEBREAK_FIRST:
11171132
for j in range(i - dups + 1, i + 1):
1118-
if ascending:
1119-
out[lexsort_indexer[j]] = j + 1 - grp_start
1120-
else:
1121-
out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
1133+
out[lexsort_indexer[j]] = j + 1 - grp_start
1134+
1135+
# If TIEBREAK_FIRST and descending, the ranking should be
1136+
# n + m, n + (m - 1) ... n + 1. This is equivalent to
1137+
# (i - dups + 1) + (i - j + 1) - grp_start
1138+
elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1139+
for j in range(i - dups + 1, i + 1):
1140+
out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
11221141
elif tiebreak == TIEBREAK_DENSE:
11231142
for j in range(i - dups + 1, i + 1):
11241143
out[lexsort_indexer[j]] = grp_vals_seen
11251144

1126-
# look forward to the next value (using the sorting in _as)
1145+
# Look forward to the next value (using the sorting in lexsort_indexer)
11271146
# if the value does not equal the current value then we need to
11281147
# reset the dups and sum_ranks, knowing that a new value is
1129-
# coming up. the conditional also needs to handle nan equality
1148+
# coming up. The conditional also needs to handle nan equality
11301149
# and the end of iteration
11311150
if next_val_diff or (mask[lexsort_indexer[i]]
11321151
^ mask[lexsort_indexer[i+1]]):
11331152
dups = sum_ranks = 0
11341153
grp_vals_seen += 1
1135-
grp_tie_count += 1
11361154

11371155
# Similar to the previous conditional, check now if we are
11381156
# moving to a new group. If so, keep track of the index where
11391157
# the new group occurs, so the tiebreaker calculations can
1140-
# decrement that from their position. fill in the size of each
1141-
# group encountered (used by pct calculations later). also be
1158+
# decrement that from their position. Fill in the size of each
1159+
# group encountered (used by pct calculations later). Also be
11421160
# sure to reset any of the items helping to calculate dups
1143-
if (at_end or
1144-
(check_labels
1145-
and (labels[lexsort_indexer[i]]
1146-
!= labels[lexsort_indexer[i+1]]))):
1161+
if group_changed:
11471162
if tiebreak != TIEBREAK_DENSE:
11481163
for j in range(grp_start, i + 1):
11491164
grp_sizes[lexsort_indexer[j]] = \
11501165
(i - grp_start + 1 - grp_na_count)
11511166
else:
11521167
for j in range(grp_start, i + 1):
11531168
grp_sizes[lexsort_indexer[j]] = \
1154-
(grp_tie_count - (grp_na_count > 0))
1169+
(grp_vals_seen - 1 - (grp_na_count > 0))
11551170
dups = sum_ranks = 0
11561171
grp_na_count = 0
1157-
grp_tie_count = 0
11581172
grp_start = i + 1
11591173
grp_vals_seen = 1
11601174
else:
11611175
with nogil:
11621176
for i in range(N):
11631177
at_end = i == N - 1
1178+
11641179
# dups and sum_ranks will be incremented each loop where
11651180
# the value / group remains the same, and should be reset
1166-
# when either of those change
1167-
# Used to calculate tiebreakers
1181+
# when either of those change. Used to calculate tiebreakers
11681182
dups += 1
11691183
sum_ranks += i - grp_start + 1
11701184

1185+
next_val_diff = at_end or (masked_vals[lexsort_indexer[i]]
1186+
!= masked_vals[lexsort_indexer[i+1]])
1187+
1188+
# We'll need this check later anyway to determine group size, so just
1189+
# compute it here since short-circuiting won't help
1190+
group_changed = at_end or (check_labels and
1191+
(labels[lexsort_indexer[i]]
1192+
!= labels[lexsort_indexer[i+1]]))
1193+
11711194
# Update out only when there is a transition of values or labels.
11721195
# When a new value or group is encountered, go back #dups steps (
11731196
# the number of occurrences of the current value) and assign the ranks
11741197
# based on the starting index of the current group (grp_start)
11751198
# and the current index
1176-
if not at_end:
1177-
next_val_diff = (masked_vals[lexsort_indexer[i]]
1178-
!= masked_vals[lexsort_indexer[i+1]])
1179-
else:
1180-
next_val_diff = True
1181-
1182-
if (next_val_diff
1183-
or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+1]])
1184-
or (check_labels
1185-
and (labels[lexsort_indexer[i]]
1186-
!= labels[lexsort_indexer[i+1]]))
1187-
):
1188-
# if keep_na, check for missing values and assign back
1199+
if (next_val_diff or group_changed
1200+
or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+1]])):
1201+
1202+
# If keep_na, check for missing values and assign back
11891203
# to the result where appropriate
11901204
if keep_na and mask[lexsort_indexer[i]]:
1205+
grp_na_count = dups
11911206
for j in range(i - dups + 1, i + 1):
11921207
out[lexsort_indexer[j]] = NaN
1193-
grp_na_count = dups
11941208
elif tiebreak == TIEBREAK_AVERAGE:
11951209
for j in range(i - dups + 1, i + 1):
11961210
out[lexsort_indexer[j]] = sum_ranks / <float64_t>dups
@@ -1200,48 +1214,51 @@ def rank_1d(
12001214
elif tiebreak == TIEBREAK_MAX:
12011215
for j in range(i - dups + 1, i + 1):
12021216
out[lexsort_indexer[j]] = i - grp_start + 1
1217+
1218+
# With n as the previous rank in the group and m as the number
1219+
# of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1220+
# then rankings should be n + 1, n + 2 ... n + m
12031221
elif tiebreak == TIEBREAK_FIRST:
12041222
for j in range(i - dups + 1, i + 1):
1205-
if ascending:
1206-
out[lexsort_indexer[j]] = j + 1 - grp_start
1207-
else:
1208-
out[lexsort_indexer[j]] = \
1209-
(2 * i - j - dups + 2 - grp_start)
1223+
out[lexsort_indexer[j]] = j + 1 - grp_start
1224+
1225+
# If TIEBREAK_FIRST and descending, the ranking should be
1226+
# n + m, n + (m - 1) ... n + 1. This is equivalent to
1227+
# (i - dups + 1) + (i - j + 1) - grp_start
1228+
elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1229+
for j in range(i - dups + 1, i + 1):
1230+
out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
12101231
elif tiebreak == TIEBREAK_DENSE:
12111232
for j in range(i - dups + 1, i + 1):
12121233
out[lexsort_indexer[j]] = grp_vals_seen
12131234

1214-
# look forward to the next value (using the sorting in
1235+
# Look forward to the next value (using the sorting in
12151236
# lexsort_indexer) if the value does not equal the current
1216-
# value then we need to reset the dups and sum_ranks,
1217-
# knowing that a new value is coming up. the conditional
1218-
# also needs to handle nan equality and the end of iteration
1237+
# value then we need to reset the dups and sum_ranks, knowing
1238+
# that a new value is coming up. The conditional also needs
1239+
# to handle nan equality and the end of iteration
12191240
if next_val_diff or (mask[lexsort_indexer[i]]
12201241
^ mask[lexsort_indexer[i+1]]):
12211242
dups = sum_ranks = 0
12221243
grp_vals_seen += 1
1223-
grp_tie_count += 1
12241244

12251245
# Similar to the previous conditional, check now if we are
12261246
# moving to a new group. If so, keep track of the index where
12271247
# the new group occurs, so the tiebreaker calculations can
1228-
# decrement that from their position. fill in the size of each
1229-
# group encountered (used by pct calculations later). also be
1248+
# decrement that from their position. Fill in the size of each
1249+
# group encountered (used by pct calculations later). Also be
12301250
# sure to reset any of the items helping to calculate dups
1231-
if at_end or (check_labels and
1232-
(labels[lexsort_indexer[i]]
1233-
!= labels[lexsort_indexer[i+1]])):
1251+
if group_changed:
12341252
if tiebreak != TIEBREAK_DENSE:
12351253
for j in range(grp_start, i + 1):
12361254
grp_sizes[lexsort_indexer[j]] = \
12371255
(i - grp_start + 1 - grp_na_count)
12381256
else:
12391257
for j in range(grp_start, i + 1):
12401258
grp_sizes[lexsort_indexer[j]] = \
1241-
(grp_tie_count - (grp_na_count > 0))
1259+
(grp_vals_seen - 1 - (grp_na_count > 0))
12421260
dups = sum_ranks = 0
12431261
grp_na_count = 0
1244-
grp_tie_count = 0
12451262
grp_start = i + 1
12461263
grp_vals_seen = 1
12471264

0 commit comments

Comments
 (0)