Skip to content

Commit 8e6af0e

Browse files
committed
BUG: Fix quantile calculation #33200
1 parent 15a27ea commit 8e6af0e

File tree

2 files changed

+41
-42
lines changed

2 files changed

+41
-42
lines changed

pandas/_libs/groupby.pyx

+39 -41
Original file line number | Diff line number | Diff line change
@@ -778,51 +778,49 @@ def group_quantile(ndarray[float64_t] out,
778778
if not mask[i]:
779779
non_na_counts[lab] += 1
780780

781-
# Get an index of values sorted by labels and then values
782-
sort_arr = np.arange(len(labels), dtype=np.int64)
783-
mask = labels != -1
784-
order = (np.asarray(values)[mask], labels[mask])
785-
sort_arr[mask] = np.lexsort(order).astype(np.int64, copy=False)
786-
with nogil:
787-
for i in range(ngroups):
788-
# Figure out how many group elements there are
789-
grp_sz = counts[i]
790-
non_na_sz = non_na_counts[i]
781+
if labels.any():
782+
# Get an index of values sorted by labels and then values
783+
labels[labels==-1] = np.max(labels) + 1
784+
order = (values, labels)
785+
sort_arr= np.lexsort(order).astype(np.int64, copy=False)
786+
with nogil:
787+
for i in range(ngroups):
788+
# Figure out how many group elements there are
789+
grp_sz = counts[i]
790+
non_na_sz = non_na_counts[i]
791791

792-
if non_na_sz == 0:
793-
out[i] = NaN
794-
else:
795-
# Calculate where to retrieve the desired value
796-
# Casting to int will intentionally truncate result
797-
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
798-
799-
val = values[sort_arr[idx]]
800-
# If requested quantile falls evenly on a particular index
801-
# then write that index's value out. Otherwise interpolate
802-
q_idx = q * (non_na_sz - 1)
803-
frac = q_idx % 1
804-
805-
if frac == 0.0 or interp == INTERPOLATION_LOWER:
806-
out[i] = val
792+
if non_na_sz == 0:
793+
out[i] = NaN
807794
else:
808-
next_val = values[sort_arr[idx + 1]]
809-
if interp == INTERPOLATION_LINEAR:
810-
out[i] = val + (next_val - val) * frac
811-
elif interp == INTERPOLATION_HIGHER:
812-
out[i] = next_val
813-
elif interp == INTERPOLATION_MIDPOINT:
814-
out[i] = (val + next_val) / 2.0
815-
elif interp == INTERPOLATION_NEAREST:
816-
if frac > .5 or (frac == .5 and q > .5): # Always OK?
795+
# Calculate where to retrieve the desired value
796+
# Casting to int will intentionally truncate result
797+
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
798+
799+
val = values[sort_arr[idx]]
800+
# If requested quantile falls evenly on a particular index
801+
# then write that index's value out. Otherwise interpolate
802+
q_idx = q * (non_na_sz - 1)
803+
frac = q_idx % 1
804+
805+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
806+
out[i] = val
807+
else:
808+
next_val = values[sort_arr[idx + 1]]
809+
if interp == INTERPOLATION_LINEAR:
810+
out[i] = val + (next_val - val) * frac
811+
elif interp == INTERPOLATION_HIGHER:
817812
out[i] = next_val
818-
else:
819-
out[i] = val
813+
elif interp == INTERPOLATION_MIDPOINT:
814+
out[i] = (val + next_val) / 2.0
815+
elif interp == INTERPOLATION_NEAREST:
816+
if frac > .5 or (frac == .5 and q > .5): # Always OK?
817+
out[i] = next_val
818+
else:
819+
out[i] = val
820+
821+
# Increment the index reference in sorted_arr for the next group
822+
grp_start += grp_sz
820823

821-
# Increment the index reference in sorted_arr for the next group
822-
grp_start += grp_sz
823-
print(out)
824-
# out = np.roll(out, -(len(out) - np.sum(counts)))
825-
print(out)
826824

827825
# ----------------------------------------------------------------------
828826
# group_nth, group_last, group_rank

pandas/tests/groupby/test_function.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -1510,7 +1510,8 @@ def test_quantile_missing_group_values_no_segfaults():
15101510
@pytest.mark.parametrize(
15111511
"key, val, expected_key, expected_val",
15121512
[
1513-
([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 1.0]),
1513+
([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 2.0]),
1514+
([1.0, np.nan, 2.0, 2.0], range(4), [1.0, 2.0], [0.0, 2.5]),
15141515
(["a", "b", "b", np.nan], range(4), ["a", "b"], [0, 1.5]),
15151516
],
15161517
)

0 commit comments

Comments
 (0)