|
7 | 7 | from pandas import DataFrame, Series
|
8 | 8 | import pandas._testing as tm
|
9 | 9 |
|
# Module-level fixtures shared by the multi-index get_loc tests.
# m: number of random lookup keys; n: number of data rows.
m = 50
n = 1000
cols = ["jim", "joe", "jolie", "joline", "jolia"]

# One random array per column; transposed below into per-row tuples.
vals = [
    np.random.randint(0, 10, n),
    np.random.choice(list("abcdefghij"), n),
    np.random.choice(pd.date_range("20141009", periods=10).tolist(), n),
    np.random.choice(list("ZYXWVUTSRQ"), n),
    np.random.randn(n),
]
vals = [tuple(row) for row in zip(*vals)]

# bunch of keys for testing: random keys drawn from a slightly wider
# value domain than the data (periods=11, 11 letters vs 10), so some
# keys may not occur in the frame, plus every (n // m)-th data row
# truncated by one level, so some keys are guaranteed to occur.
keys = [
    np.random.randint(0, 11, m),
    np.random.choice(list("abcdefghijk"), m),
    np.random.choice(pd.date_range("20141009", periods=11).tolist(), m),
    np.random.choice(list("ZYXWVUTSRQP"), m),
]
keys = [tuple(row) for row in zip(*keys)]
keys += [row[:-1] for row in vals[:: n // m]]


# covers both unique index and non-unique index
df = DataFrame(vals, columns=cols)
a = pd.concat([df, df])
b = df.drop_duplicates(subset=cols[:-1])
| 38 | + |
10 | 39 |
|
11 |
| -@pytest.mark.slow |
12 | 40 | @pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning")
|
13 |
| -def test_multiindex_get_loc(): # GH7724, GH2646 |
| 41 | +@pytest.mark.parametrize("lexsort_depth", list(range(5))) |
| 42 | +@pytest.mark.parametrize("key", keys) |
| 43 | +@pytest.mark.parametrize("frame", [a, b]) |
| 44 | +def test_multiindex_get_loc(lexsort_depth, key, frame): |
| 45 | + # GH7724, GH2646 |
14 | 46 |
|
15 | 47 | with warnings.catch_warnings(record=True):
|
16 | 48 |
|
17 | 49 | # test indexing into a multi-index before & past the lexsort depth
|
18 |
| - from numpy.random import choice, randint, randn |
19 |
| - |
20 |
| - cols = ["jim", "joe", "jolie", "joline", "jolia"] |
21 | 50 |
|
22 | 51 | def validate(mi, df, key):
|
23 | 52 | mask = np.ones(len(df)).astype("bool")
|
@@ -51,38 +80,11 @@ def validate(mi, df, key):
|
51 | 80 | else: # multi hit
|
52 | 81 | tm.assert_frame_equal(mi.loc[key[: i + 1]], right)
|
53 | 82 |
|
54 |
| - def loop(mi, df, keys): |
55 |
| - for key in keys: |
56 |
| - validate(mi, df, key) |
57 |
| - |
58 |
| - n, m = 1000, 50 |
59 |
| - |
60 |
| - vals = [ |
61 |
| - randint(0, 10, n), |
62 |
| - choice(list("abcdefghij"), n), |
63 |
| - choice(pd.date_range("20141009", periods=10).tolist(), n), |
64 |
| - choice(list("ZYXWVUTSRQ"), n), |
65 |
| - randn(n), |
66 |
| - ] |
67 |
| - vals = list(map(tuple, zip(*vals))) |
68 |
| - |
69 |
| - # bunch of keys for testing |
70 |
| - keys = [ |
71 |
| - randint(0, 11, m), |
72 |
| - choice(list("abcdefghijk"), m), |
73 |
| - choice(pd.date_range("20141009", periods=11).tolist(), m), |
74 |
| - choice(list("ZYXWVUTSRQP"), m), |
75 |
| - ] |
76 |
| - keys = list(map(tuple, zip(*keys))) |
77 |
| - keys += list(map(lambda t: t[:-1], vals[:: n // m])) |
78 |
| - |
79 |
| - # covers both unique index and non-unique index |
80 |
| - df = DataFrame(vals, columns=cols) |
81 |
| - a, b = pd.concat([df, df]), df.drop_duplicates(subset=cols[:-1]) |
82 |
| - |
83 |
| - for frame in a, b: |
84 |
| - for i in range(5): # lexsort depth |
85 |
| - df = frame.copy() if i == 0 else frame.sort_values(by=cols[:i]) |
86 |
| - mi = df.set_index(cols[:-1]) |
87 |
| - assert not mi.index.lexsort_depth < i |
88 |
| - loop(mi, df, keys) |
| 83 | + if lexsort_depth == 0: |
| 84 | + df = frame.copy() |
| 85 | + else: |
| 86 | + df = frame.sort_values(by=cols[:lexsort_depth]) |
| 87 | + |
| 88 | + mi = df.set_index(cols[:-1]) |
| 89 | + assert not mi.index.lexsort_depth < lexsort_depth |
| 90 | + validate(mi, df, key) |
0 commit comments