Skip to content

Commit 94bae0b

Browse files
committed
Patched the template, and added a test for '.shift()'
1 parent fe2f0ec commit 94bae0b

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

pandas/src/algos_groupby_helper.pxi.in

+6
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,12 @@ def group_shift_indexer(int64_t[:] out, int64_t[:] labels,
700700
## reverse iterator if shifting backwards
701701
ii = offset + sign * i
702702
lab = labels[ii]
703+
704+
# Skip null keys
705+
if lab == -1:
706+
out[ii] = -1
707+
continue
708+
703709
label_seen[lab] += 1
704710

705711
idxer_slot = label_seen[lab] % periods

pandas/tests/test_groupby.py

+25
Original file line numberDiff line numberDiff line change
@@ -6560,6 +6560,31 @@ def test_grouping_string_repr(self):
65606560
expected = "Grouping(('A', 'a'))"
65616561
tm.assert_equal(result, expected)
65626562

6563+
def test_group_shift_with_null_key(self):
6564+
# This test is designed to replicate the segfault in issue #13813.
6565+
n_rows = 1200
6566+
6567+
# Generate a moderately large dataframe with occasional missing
6568+
# values in column `B`, and then group by [`A`, `B`]. This should
6569+
# force `-1` in `labels` array of `gr_.grouper.group_info` exactly
6570+
# at those places, where the group-by key is partially missing.
6571+
df = pd.DataFrame([(i%12, i%3 if i%3 else float("nan"), i)
6572+
for i in range(n_rows)], dtype=float,
6573+
columns=["A", "B", "Z"], index=None)
6574+
gr_ = df.groupby(["A", "B"])
6575+
6576+
# Generate the expected dataframe
6577+
expected = pd.DataFrame([(i%12, i%3 if i%3 else float("nan"),
6578+
i + 12 if i%3 and i < n_rows - 12 \
6579+
else float("nan"))
6580+
for i in range(n_rows)], dtype=float,
6581+
columns=["A", "B", "Z"], index=None)
6582+
result = gr_.shift(-1)
6583+
6584+
# Check for data grabbed from beyond the acceptable array bounds
6585+
# in case there was no segfault.
6586+
tm.assert_frame_equal(result, expected[["Z"]])
6587+
65636588

65646589
def assert_fp_equal(a, b):
65656590
assert (np.abs(a - b) < 1e-12).all()

0 commit comments

Comments
 (0)