|
8 | 8 |
|
9 | 9 | import pandas as pd
|
10 | 10 | from pandas.compat import lrange, StringIO
|
11 |
| -from pandas import Series, DataFrame, Timestamp, date_range, MultiIndex |
| 11 | +from pandas import Series, DataFrame, Timestamp, date_range, MultiIndex, Index |
12 | 12 | from pandas.util import testing as tm
|
13 | 13 | from pandas.tests.indexing.common import Base
|
14 | 14 |
|
@@ -711,3 +711,44 @@ def test_identity_slice_returns_new_object(self):
|
711 | 711 |
|
712 | 712 | original_series[:3] = [7, 8, 9]
|
713 | 713 | assert all(sliced_series[:3] == [7, 8, 9])
|
| 714 | + |
| 715 | + @pytest.mark.parametrize( |
| 716 | + 'indexer_type_1', |
| 717 | + (list, tuple, set, slice, np.ndarray, Series, Index)) |
| 718 | + @pytest.mark.parametrize( |
| 719 | + 'indexer_type_2', |
| 720 | + (list, tuple, set, slice, np.ndarray, Series, Index)) |
| 721 | + def test_loc_getitem_nested_indexer(self, indexer_type_1, indexer_type_2): |
| 722 | + # GH #19686 |
| 723 | + # .loc should work with nested indexers which can be |
| 724 | + # any list-like objects (see `pandas.api.types.is_list_like`) or slices |
| 725 | + |
| 726 | + def convert_nested_indexer(indexer_type, keys): |
| 727 | + if indexer_type == np.ndarray: |
| 728 | + return np.array(keys) |
| 729 | + if indexer_type == slice: |
| 730 | + return slice(*keys) |
| 731 | + return indexer_type(keys) |
| 732 | + |
| 733 | + a = [10, 20, 30] |
| 734 | + b = [1, 2, 3] |
| 735 | + index = pd.MultiIndex.from_product([a, b]) |
| 736 | + df = pd.DataFrame( |
| 737 | + np.arange(len(index), dtype='int64'), |
| 738 | + index=index, columns=['Data']) |
| 739 | + |
| 740 | + keys = ([10, 20], [2, 3]) |
| 741 | + types = (indexer_type_1, indexer_type_2) |
| 742 | + |
| 743 | + # check indexers with all the combinations of nested objects |
| 744 | + # of all the valid types |
| 745 | + indexer = tuple( |
| 746 | + convert_nested_indexer(indexer_type, k) |
| 747 | + for indexer_type, k in zip(types, keys)) |
| 748 | + |
| 749 | + result = df.loc[indexer, 'Data'] |
| 750 | + expected = pd.Series( |
| 751 | + [1, 2, 4, 5], name='Data', |
| 752 | + index=pd.MultiIndex.from_product(keys)) |
| 753 | + |
| 754 | + tm.assert_series_equal(result, expected) |
0 commit comments