@@ -995,22 +995,27 @@ def rank_1d(
995
995
cdef:
996
996
TiebreakEnumType tiebreak
997
997
Py_ssize_t i, j, N, grp_start= 0 , dups= 0 , sum_ranks= 0
998
- Py_ssize_t grp_vals_seen= 1 , grp_na_count= 0 , grp_tie_count = 0
998
+ Py_ssize_t grp_vals_seen= 1 , grp_na_count= 0
999
999
ndarray[int64_t, ndim= 1 ] lexsort_indexer
1000
1000
ndarray[float64_t, ndim= 1 ] grp_sizes, out
1001
1001
ndarray[rank_t, ndim= 1 ] masked_vals
1002
1002
ndarray[uint8_t, ndim= 1 ] mask
1003
- bint keep_na, at_end, next_val_diff, check_labels
1003
+ bint keep_na, at_end, next_val_diff, check_labels, group_changed
1004
1004
rank_t nan_fill_val
1005
1005
1006
1006
tiebreak = tiebreakers[ties_method]
1007
+ if tiebreak == TIEBREAK_FIRST:
1008
+ if not ascending:
1009
+ tiebreak = TIEBREAK_FIRST_DESCENDING
1010
+
1007
1011
keep_na = na_option == ' keep'
1008
1012
1009
1013
N = len (values)
1010
1014
# TODO Cython 3.0: cast won't be necessary (#2992)
1011
1015
assert < Py_ssize_t> len (labels) == N
1012
1016
out = np.empty(N)
1013
1017
grp_sizes = np.ones(N)
1018
+
1014
1019
# If all 0 labels, can short-circuit later label
1015
1020
# comparisons
1016
1021
check_labels = np.any(labels)
@@ -1032,6 +1037,12 @@ def rank_1d(
1032
1037
else :
1033
1038
mask = np.zeros(shape = len (masked_vals), dtype = np.uint8)
1034
1039
1040
+ # If `na_option == 'top'`, we want to assign the lowest rank
1041
+ # to NaN regardless of ascending/descending. So if ascending,
1042
+ # fill with lowest value of type to end up with lowest rank.
1043
+ # If descending, fill with highest value since descending
1044
+ # will flip the ordering to still end up with lowest rank.
1045
+ # Symmetric logic applies to `na_option == 'bottom'`
1035
1046
if ascending ^ (na_option == ' top' ):
1036
1047
if rank_t is object :
1037
1048
nan_fill_val = Infinity()
@@ -1074,36 +1085,36 @@ def rank_1d(
1074
1085
if rank_t is object :
1075
1086
for i in range (N):
1076
1087
at_end = i == N - 1
1088
+
1077
1089
# dups and sum_ranks will be incremented each loop where
1078
1090
# the value / group remains the same, and should be reset
1079
- # when either of those change
1080
- # Used to calculate tiebreakers
1091
+ # when either of those change. Used to calculate tiebreakers
1081
1092
dups += 1
1082
1093
sum_ranks += i - grp_start + 1
1083
1094
1095
+ next_val_diff = at_end or are_diff(masked_vals[lexsort_indexer[i]],
1096
+ masked_vals[lexsort_indexer[i+ 1 ]])
1097
+
1098
+ # We'll need this check later anyway to determine group size, so just
1099
+ # compute it here since shortcircuiting won't help
1100
+ group_changed = at_end or (check_labels and
1101
+ (labels[lexsort_indexer[i]]
1102
+ != labels[lexsort_indexer[i+ 1 ]]))
1103
+
1084
1104
# Update out only when there is a transition of values or labels.
1085
1105
# When a new value or group is encountered, go back #dups steps(
1086
1106
# the number of occurrence of current value) and assign the ranks
1087
1107
# based on the starting index of the current group (grp_start)
1088
1108
# and the current index
1089
- if not at_end:
1090
- next_val_diff = are_diff(masked_vals[lexsort_indexer[i]],
1091
- masked_vals[lexsort_indexer[i+ 1 ]])
1092
- else :
1093
- next_val_diff = True
1094
-
1095
- if (next_val_diff
1096
- or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])
1097
- or (check_labels
1098
- and (labels[lexsort_indexer[i]]
1099
- != labels[lexsort_indexer[i+ 1 ]]))
1100
- ):
1101
- # if keep_na, check for missing values and assign back
1109
+ if (next_val_diff or group_changed
1110
+ or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])):
1111
+
1112
+ # If keep_na, check for missing values and assign back
1102
1113
# to the result where appropriate
1103
1114
if keep_na and mask[lexsort_indexer[i]]:
1115
+ grp_na_count = dups
1104
1116
for j in range (i - dups + 1 , i + 1 ):
1105
1117
out[lexsort_indexer[j]] = NaN
1106
- grp_na_count = dups
1107
1118
elif tiebreak == TIEBREAK_AVERAGE:
1108
1119
for j in range (i - dups + 1 , i + 1 ):
1109
1120
out[lexsort_indexer[j]] = sum_ranks / < float64_t> dups
@@ -1113,84 +1124,87 @@ def rank_1d(
1113
1124
elif tiebreak == TIEBREAK_MAX:
1114
1125
for j in range (i - dups + 1 , i + 1 ):
1115
1126
out[lexsort_indexer[j]] = i - grp_start + 1
1127
+
1128
+ # With n as the previous rank in the group and m as the number
1129
+ # of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1130
+ # then rankings should be n + 1, n + 2 ... n + m
1116
1131
elif tiebreak == TIEBREAK_FIRST:
1117
1132
for j in range (i - dups + 1 , i + 1 ):
1118
- if ascending:
1119
- out[lexsort_indexer[j]] = j + 1 - grp_start
1120
- else :
1121
- out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
1133
+ out[lexsort_indexer[j]] = j + 1 - grp_start
1134
+
1135
+ # If TIEBREAK_FIRST and descending, the ranking should be
1136
+ # n + m, n + (m - 1) ... n + 1. This is equivalent to
1137
+ # (i - dups + 1) + (i - j + 1) - grp_start
1138
+ elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1139
+ for j in range (i - dups + 1 , i + 1 ):
1140
+ out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
1122
1141
elif tiebreak == TIEBREAK_DENSE:
1123
1142
for j in range (i - dups + 1 , i + 1 ):
1124
1143
out[lexsort_indexer[j]] = grp_vals_seen
1125
1144
1126
- # look forward to the next value (using the sorting in _as )
1145
+ # Look forward to the next value (using the sorting in lexsort_indexer )
1127
1146
# if the value does not equal the current value then we need to
1128
1147
# reset the dups and sum_ranks, knowing that a new value is
1129
- # coming up. the conditional also needs to handle nan equality
1148
+ # coming up. The conditional also needs to handle nan equality
1130
1149
# and the end of iteration
1131
1150
if next_val_diff or (mask[lexsort_indexer[i]]
1132
1151
^ mask[lexsort_indexer[i+ 1 ]]):
1133
1152
dups = sum_ranks = 0
1134
1153
grp_vals_seen += 1
1135
- grp_tie_count += 1
1136
1154
1137
1155
# Similar to the previous conditional, check now if we are
1138
1156
# moving to a new group. If so, keep track of the index where
1139
1157
# the new group occurs, so the tiebreaker calculations can
1140
- # decrement that from their position. fill in the size of each
1141
- # group encountered (used by pct calculations later). also be
1158
+ # decrement that from their position. Fill in the size of each
1159
+ # group encountered (used by pct calculations later). Also be
1142
1160
# sure to reset any of the items helping to calculate dups
1143
- if (at_end or
1144
- (check_labels
1145
- and (labels[lexsort_indexer[i]]
1146
- != labels[lexsort_indexer[i+ 1 ]]))):
1161
+ if group_changed:
1147
1162
if tiebreak != TIEBREAK_DENSE:
1148
1163
for j in range (grp_start, i + 1 ):
1149
1164
grp_sizes[lexsort_indexer[j]] = \
1150
1165
(i - grp_start + 1 - grp_na_count)
1151
1166
else :
1152
1167
for j in range (grp_start, i + 1 ):
1153
1168
grp_sizes[lexsort_indexer[j]] = \
1154
- (grp_tie_count - (grp_na_count > 0 ))
1169
+ (grp_vals_seen - 1 - (grp_na_count > 0 ))
1155
1170
dups = sum_ranks = 0
1156
1171
grp_na_count = 0
1157
- grp_tie_count = 0
1158
1172
grp_start = i + 1
1159
1173
grp_vals_seen = 1
1160
1174
else :
1161
1175
with nogil:
1162
1176
for i in range (N):
1163
1177
at_end = i == N - 1
1178
+
1164
1179
# dups and sum_ranks will be incremented each loop where
1165
1180
# the value / group remains the same, and should be reset
1166
- # when either of those change
1167
- # Used to calculate tiebreakers
1181
+ # when either of those change. Used to calculate tiebreakers
1168
1182
dups += 1
1169
1183
sum_ranks += i - grp_start + 1
1170
1184
1185
+ next_val_diff = at_end or (masked_vals[lexsort_indexer[i]]
1186
+ != masked_vals[lexsort_indexer[i+ 1 ]])
1187
+
1188
+ # We'll need this check later anyway to determine group size, so just
1189
+ # compute it here since shortcircuiting won't help
1190
+ group_changed = at_end or (check_labels and
1191
+ (labels[lexsort_indexer[i]]
1192
+ != labels[lexsort_indexer[i+ 1 ]]))
1193
+
1171
1194
# Update out only when there is a transition of values or labels.
1172
1195
# When a new value or group is encountered, go back #dups steps(
1173
1196
# the number of occurrence of current value) and assign the ranks
1174
1197
# based on the starting index of the current group (grp_start)
1175
1198
# and the current index
1176
- if not at_end:
1177
- next_val_diff = (masked_vals[lexsort_indexer[i]]
1178
- != masked_vals[lexsort_indexer[i+ 1 ]])
1179
- else :
1180
- next_val_diff = True
1181
-
1182
- if (next_val_diff
1183
- or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])
1184
- or (check_labels
1185
- and (labels[lexsort_indexer[i]]
1186
- != labels[lexsort_indexer[i+ 1 ]]))
1187
- ):
1188
- # if keep_na, check for missing values and assign back
1199
+ if (next_val_diff or group_changed
1200
+ or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])):
1201
+
1202
+ # If keep_na, check for missing values and assign back
1189
1203
# to the result where appropriate
1190
1204
if keep_na and mask[lexsort_indexer[i]]:
1205
+ grp_na_count = dups
1191
1206
for j in range (i - dups + 1 , i + 1 ):
1192
1207
out[lexsort_indexer[j]] = NaN
1193
- grp_na_count = dups
1194
1208
elif tiebreak == TIEBREAK_AVERAGE:
1195
1209
for j in range (i - dups + 1 , i + 1 ):
1196
1210
out[lexsort_indexer[j]] = sum_ranks / < float64_t> dups
@@ -1200,48 +1214,51 @@ def rank_1d(
1200
1214
elif tiebreak == TIEBREAK_MAX:
1201
1215
for j in range (i - dups + 1 , i + 1 ):
1202
1216
out[lexsort_indexer[j]] = i - grp_start + 1
1217
+
1218
+ # With n as the previous rank in the group and m as the number
1219
+ # of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1220
+ # then rankings should be n + 1, n + 2 ... n + m
1203
1221
elif tiebreak == TIEBREAK_FIRST:
1204
1222
for j in range (i - dups + 1 , i + 1 ):
1205
- if ascending:
1206
- out[lexsort_indexer[j]] = j + 1 - grp_start
1207
- else :
1208
- out[lexsort_indexer[j]] = \
1209
- (2 * i - j - dups + 2 - grp_start)
1223
+ out[lexsort_indexer[j]] = j + 1 - grp_start
1224
+
1225
+ # If TIEBREAK_FIRST and descending, the ranking should be
1226
+ # n + m, n + (m - 1) ... n + 1. This is equivalent to
1227
+ # (i - dups + 1) + (i - j + 1) - grp_start
1228
+ elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1229
+ for j in range (i - dups + 1 , i + 1 ):
1230
+ out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
1210
1231
elif tiebreak == TIEBREAK_DENSE:
1211
1232
for j in range (i - dups + 1 , i + 1 ):
1212
1233
out[lexsort_indexer[j]] = grp_vals_seen
1213
1234
1214
- # look forward to the next value (using the sorting in
1235
+ # Look forward to the next value (using the sorting in
1215
1236
# lexsort_indexer) if the value does not equal the current
1216
- # value then we need to reset the dups and sum_ranks,
1217
- # knowing that a new value is coming up. the conditional
1218
- # also needs to handle nan equality and the end of iteration
1237
+ # value then we need to reset the dups and sum_ranks, knowing
1238
+ # that a new value is coming up. The conditional also needs
1239
+ # to handle nan equality and the end of iteration
1219
1240
if next_val_diff or (mask[lexsort_indexer[i]]
1220
1241
^ mask[lexsort_indexer[i+ 1 ]]):
1221
1242
dups = sum_ranks = 0
1222
1243
grp_vals_seen += 1
1223
- grp_tie_count += 1
1224
1244
1225
1245
# Similar to the previous conditional, check now if we are
1226
1246
# moving to a new group. If so, keep track of the index where
1227
1247
# the new group occurs, so the tiebreaker calculations can
1228
- # decrement that from their position. fill in the size of each
1229
- # group encountered (used by pct calculations later). also be
1248
+ # decrement that from their position. Fill in the size of each
1249
+ # group encountered (used by pct calculations later). Also be
1230
1250
# sure to reset any of the items helping to calculate dups
1231
- if at_end or (check_labels and
1232
- (labels[lexsort_indexer[i]]
1233
- != labels[lexsort_indexer[i+ 1 ]])):
1251
+ if group_changed:
1234
1252
if tiebreak != TIEBREAK_DENSE:
1235
1253
for j in range (grp_start, i + 1 ):
1236
1254
grp_sizes[lexsort_indexer[j]] = \
1237
1255
(i - grp_start + 1 - grp_na_count)
1238
1256
else :
1239
1257
for j in range (grp_start, i + 1 ):
1240
1258
grp_sizes[lexsort_indexer[j]] = \
1241
- (grp_tie_count - (grp_na_count > 0 ))
1259
+ (grp_vals_seen - 1 - (grp_na_count > 0 ))
1242
1260
dups = sum_ranks = 0
1243
1261
grp_na_count = 0
1244
- grp_tie_count = 0
1245
1262
grp_start = i + 1
1246
1263
grp_vals_seen = 1
1247
1264
0 commit comments