diff --git a/doc/source/whatsnew/v0.22.0.txt b/doc/source/whatsnew/v0.22.0.txt index 712119caae6f2..7125aaed8b4b7 100644 --- a/doc/source/whatsnew/v0.22.0.txt +++ b/doc/source/whatsnew/v0.22.0.txt @@ -99,7 +99,7 @@ Indexing ^^^^^^^^ - Bug in :func:`Series.truncate` which raises ``TypeError`` with a monotonic ``PeriodIndex`` (:issue:`17717`) -- Bug in :func:`DataFrame.groupby` where key as tuple in a ``MultiIndex`` were interpreted as a list of keys (:issue:`17979`) +- Bug in :func:`DataFrame.groupby` where tuples were interpreted as lists of keys rather than as keys (:issue:`17979`) - - diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f1edfe276dfad..8b2a15e6d1666 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5092,14 +5092,15 @@ def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, Parameters ---------- - by : mapping, function, str, or iterable + by : mapping, function, label, or list of labels Used to determine the groups for the groupby. If ``by`` is a function, it's called on each value of the object's index. If a dict or Series is passed, the Series or dict VALUES will be used to determine the groups (the Series' values are first aligned; see ``.align()`` method). If an ndarray is passed, the - values are used as-is determine the groups. A str or list of strs - may be passed to group by the columns in ``self`` + values are used as-is determine the groups. A label or list of + labels may be passed to group by the columns in ``self``. Notice + that a tuple is interpreted a (single) key. axis : int, default 0 level : int, level name, or sequence of such, default None If the axis is a MultiIndex (hierarchical), group by a particular diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 8db75accc84e5..406b8dff40843 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -2704,7 +2704,6 @@ def _get_grouper(obj, key=None, axis=0, level=None, sort=True, """ group_axis = obj._get_axis(axis) - is_axis_multiindex = isinstance(obj._info_axis, MultiIndex) # validate that the passed single level is compatible with the passed # axis of the object @@ -2765,9 +2764,8 @@ def _get_grouper(obj, key=None, axis=0, level=None, sort=True, elif isinstance(key, BaseGrouper): return key, [], obj - # when MultiIndex, allow tuple to be a key - if not isinstance(key, (tuple, list)) or \ - (isinstance(key, tuple) and is_axis_multiindex): + # Everything which is not a list is a key (including tuples): + if not isinstance(key, list): keys = [key] match_axis_length = False else: diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 2f750a7621905..a763dd170674d 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -264,7 +264,7 @@ def test_len(self): df = pd.DataFrame(dict(a=[np.nan] * 3, b=[1, 2, 3])) assert len(df.groupby(('a'))) == 0 assert len(df.groupby(('b'))) == 3 - assert len(df.groupby(('a', 'b'))) == 3 + assert len(df.groupby(['a', 'b'])) == 3 def test_basic_regression(self): # regression diff --git a/pandas/tests/groupby/test_grouping.py b/pandas/tests/groupby/test_grouping.py index 9e6de8749952f..d94c691b18265 100644 --- a/pandas/tests/groupby/test_grouping.py +++ b/pandas/tests/groupby/test_grouping.py @@ -366,13 +366,18 @@ def test_groupby_multiindex_tuple(self): result = df.groupby(('b', 1)).groups tm.assert_dict_equal(expected, result) - df2 = pd.DataFrame([[1, 2, 3, 4], [3, 4, 5, 6], [1, 4, 2, 3]], + df2 = pd.DataFrame(df.values, columns=pd.MultiIndex.from_arrays( [['a', 'b', 'b', 'c'], ['d', 'd', 'e', 'e']])) - df2.groupby([('b', 'd')]).groups - expected = df.groupby([('b', 'd')]).groups - result = df.groupby(('b', 'd')).groups + expected = df2.groupby([('b', 'd')]).groups + result = df.groupby(('b', 1)).groups + tm.assert_dict_equal(expected, result) + + df3 = pd.DataFrame(df.values, + columns=[('a', 'd'), ('b', 'd'), ('b', 'e'), 'c']) + expected = df3.groupby([('b', 'd')]).groups + result = df.groupby(('b', 1)).groups tm.assert_dict_equal(expected, result) @pytest.mark.parametrize('sort', [True, False]) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 501fe63137cf4..2a408b85f0ed1 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -202,7 +202,7 @@ def test_nth(self): freq='B') df = DataFrame(1, index=business_dates, columns=['a', 'b']) # get the first, fourth and last two business days for each month - key = (df.index.year, df.index.month) + key = [df.index.year, df.index.month] result = df.groupby(key, as_index=False).nth([0, 3, -2, -1]) expected_dates = pd.to_datetime( ['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30', '2014/5/1', diff --git a/pandas/tests/groupby/test_value_counts.py b/pandas/tests/groupby/test_value_counts.py index 3d7977c63eeb6..1434656115d18 100644 --- a/pandas/tests/groupby/test_value_counts.py +++ b/pandas/tests/groupby/test_value_counts.py @@ -43,7 +43,7 @@ def seed_df(seed_nans, n, m): df = seed_df(seed_nans, n, m) bins = None, np.arange(0, max(5, df['3rd'].max()) + 1, 2) - keys = '1st', '2nd', ('1st', '2nd') + keys = '1st', '2nd', ['1st', '2nd'] for k, b in product(keys, bins): binned.append((df, k, b, n, m)) ids.append("{}-{}-{}".format(k, n, m))