From b694b0996340fc30b0e847111df59522d63b2219 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Thu, 3 Oct 2019 18:41:37 -0700 Subject: [PATCH 01/12] Removed block code --- pandas/core/groupby/generic.py | 132 ++++----------------------------- 1 file changed, 14 insertions(+), 118 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index e556708dc9283..0fdbeef3b91dc 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -56,7 +56,6 @@ ) from pandas.core.index import Index, MultiIndex, _all_indexes_same import pandas.core.indexes.base as ibase -from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.plotting import boxplot_frame_groupby @@ -147,93 +146,6 @@ def _iterate_slices(self): continue yield val, slicer(val) - def _cython_agg_general(self, how, alt=None, numeric_only=True, min_count=-1): - new_items, new_blocks = self._cython_agg_blocks( - how, alt=alt, numeric_only=numeric_only, min_count=min_count - ) - return self._wrap_agged_blocks(new_items, new_blocks) - - _block_agg_axis = 0 - - def _cython_agg_blocks(self, how, alt=None, numeric_only=True, min_count=-1): - # TODO: the actual managing of mgr_locs is a PITA - # here, it should happen via BlockManager.combine - - data, agg_axis = self._get_data_to_aggregate() - - if numeric_only: - data = data.get_numeric_data(copy=False) - - new_blocks = [] - new_items = [] - deleted_items = [] - no_result = object() - for block in data.blocks: - # Avoid inheriting result from earlier in the loop - result = no_result - locs = block.mgr_locs.as_array - try: - result, _ = self.grouper.aggregate( - block.values, how, axis=agg_axis, min_count=min_count - ) - except NotImplementedError: - # generally if we have numeric_only=False - # and non-applicable functions - # try to python agg - - if alt is None: - # we cannot perform the operation - # in an alternate way, exclude the block - deleted_items.append(locs) - continue - - # call our grouper again with only this block - obj = self.obj[data.items[locs]] - s = groupby(obj, self.grouper) - try: - result = s.aggregate(lambda x: alt(x, axis=self.axis)) - except TypeError: - # we may have an exception in trying to aggregate - # continue and exclude the block - deleted_items.append(locs) - continue - finally: - if result is not no_result: - # see if we can cast the block back to the original dtype - result = maybe_downcast_numeric(result, block.dtype) - newb = block.make_block(result) - - new_items.append(locs) - new_blocks.append(newb) - - if len(new_blocks) == 0: - raise DataError("No numeric types to aggregate") - - # reset the locs in the blocks to correspond to our - # current ordering - indexer = np.concatenate(new_items) - new_items = data.items.take(np.sort(indexer)) - - if len(deleted_items): - - # we need to adjust the indexer to account for the - # items we have removed - # really should be done in internals :< - - deleted = np.concatenate(deleted_items) - ai = np.arange(len(data)) - mask = np.zeros(len(data)) - mask[deleted] = 1 - indexer = (ai - mask.cumsum())[indexer] - - offset = 0 - for b in new_blocks: - loc = len(b.mgr_locs) - b.mgr_locs = indexer[offset : (offset + loc)] - offset += loc - - return new_items, new_blocks - def aggregate(self, func, *args, **kwargs): _level = kwargs.pop("_level", None) @@ -1385,7 +1297,6 @@ class DataFrameGroupBy(NDFrameGroupBy): _apply_whitelist = base.dataframe_apply_whitelist - _block_agg_axis = 1 _agg_see_also_doc = dedent( """ @@ -1571,24 +1482,6 @@ def _wrap_aggregated_output(self, output, names=None): def _wrap_transformed_output(self, output, names=None): return DataFrame(output, index=self.obj.index) - def _wrap_agged_blocks(self, items, blocks): - if not self.as_index: - index = np.arange(blocks[0].values.shape[-1]) - mgr = BlockManager(blocks, [items, index]) - result = DataFrame(mgr) - - self._insert_inaxis_grouper_inplace(result) - result = result._consolidate() - else: - index = self.grouper.result_index - mgr = BlockManager(blocks, [items, index]) - result = DataFrame(mgr) - - if self.axis == 1: - result = result.T - - return self._reindex_output(result)._convert(datetime=True) - def _iterate_column_groupbys(self): for i, colname in enumerate(self._selected_obj.columns): yield colname, SeriesGroupBy( @@ -1616,20 +1509,23 @@ def count(self): DataFrame Count of values within each group. """ - data, _ = self._get_data_to_aggregate() - ids, _, ngroups = self.grouper.group_info - mask = ids != -1 + obj = self._selected_obj - val = ( - (mask & ~_isna_ndarraylike(np.atleast_2d(blk.get_values()))) - for blk in data.blocks - ) - loc = (blk.mgr_locs for blk in data.blocks) + def groupby_series(obj, col=None): + return SeriesGroupBy(obj, selection=col, grouper=self.grouper).count() + + if isinstance(obj, Series): + results = groupby_series(obj) + else: + from pandas.core.reshape.concat import concat - counter = partial(lib.count_level_2d, labels=ids, max_bin=ngroups, axis=1) - blk = map(make_block, map(counter, val), loc) + results = [groupby_series(obj[col], col) for col in obj.columns] + results = concat(results, axis=1) + results.columns.names = obj.columns.names - return self._wrap_agged_blocks(data.items, list(blk)) + if not self.as_index: + results.index = ibase.default_index(len(results)) + return results def nunique(self, dropna=True): """ From ee85e5a7ecc2cbd7aef0af7b954d46b83ddbadff Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Thu, 3 Oct 2019 19:34:17 -0700 Subject: [PATCH 02/12] Changed impl; fixed test --- pandas/core/groupby/generic.py | 28 ++++++++++-------------- pandas/tests/groupby/test_categorical.py | 2 +- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 0fdbeef3b91dc..e271399b43785 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1509,23 +1509,19 @@ def count(self): DataFrame Count of values within each group. """ - obj = self._selected_obj - - def groupby_series(obj, col=None): - return SeriesGroupBy(obj, selection=col, grouper=self.grouper).count() - - if isinstance(obj, Series): - results = groupby_series(obj) - else: - from pandas.core.reshape.concat import concat + output = OrderedDict() - results = [groupby_series(obj[col], col) for col in obj.columns] - results = concat(results, axis=1) - results.columns.names = obj.columns.names - - if not self.as_index: - results.index = ibase.default_index(len(results)) - return results + # TODO: dispatch to _cython_agg_general instead of custom looping + # TODO: refactor with series logic + ids, _, ngroups = self.grouper.group_info + for name, obj in self._iterate_slices(): + mask = (ids != -1) & ~isna(obj) + ids = ensure_platform_int(ids) + minlength = ngroups or 0 + out = np.bincount(ids[mask], minlength=minlength) + output[name] = out + + return self._wrap_aggregated_output(output) def nunique(self, dropna=True): """ diff --git a/pandas/tests/groupby/test_categorical.py b/pandas/tests/groupby/test_categorical.py index fcc0aa3b1c015..7303ffe688475 100644 --- a/pandas/tests/groupby/test_categorical.py +++ b/pandas/tests/groupby/test_categorical.py @@ -300,7 +300,7 @@ def test_observed(observed): exp_index = CategoricalIndex( list("ab"), name="cat", categories=list("abc"), ordered=True ) - expected = DataFrame({"ints": [1.5, 1.5], "val": [20.0, 30]}, index=exp_index) + expected = DataFrame({"ints": [1.5, 1.5], "val": [20, 30]}, index=exp_index) if not observed: index = CategoricalIndex( list("abc"), name="cat", categories=list("abc"), ordered=True From 32d1b6b5a00f490b3e2780a2e7a7a1c8ffb3f24b Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Sun, 6 Oct 2019 18:35:15 -0700 Subject: [PATCH 03/12] More cleanup --- pandas/core/groupby/generic.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index e271399b43785..87b9ccfb77329 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1436,13 +1436,6 @@ def _wrap_generic_output(self, result, obj): else: return DataFrame(result, index=obj.index, columns=result_index) - def _get_data_to_aggregate(self): - obj = self._obj_with_exclusions - if self.axis == 1: - return obj.T._data, 1 - else: - return obj._data, 1 - def _insert_inaxis_grouper_inplace(self, result): # zip in reverse so we can always insert at loc 0 izip = zip( From 61153234101fa40fab6b5894b45ff98466d68478 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 10:39:15 -0700 Subject: [PATCH 04/12] More robust _wrap_aggregated_output, test fix for median --- pandas/core/groupby/generic.py | 14 ++++++++------ pandas/tests/groupby/test_function.py | 8 +++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 87b9ccfb77329..89c0abad1eb0c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1454,18 +1454,20 @@ def _insert_inaxis_grouper_inplace(self, result): result.insert(0, name, lev) def _wrap_aggregated_output(self, output, names=None): - agg_axis = 0 if self.axis == 1 else 1 - agg_labels = self._obj_with_exclusions._get_axis(agg_axis) - - output_keys = self._decide_output_index(output, agg_labels) + if isinstance(output, dict): + result = DataFrame(output) + else: + agg_axis = 0 if self.axis == 1 else 1 + agg_labels = self._obj_with_exclusions._get_axis(agg_axis) + output_keys = self._decide_output_index(output, agg_labels) + result = DataFrame(output, columns=output_keys) if not self.as_index: - result = DataFrame(output, columns=output_keys) self._insert_inaxis_grouper_inplace(result) result = result._consolidate() else: index = self.grouper.result_index - result = DataFrame(output, index=index, columns=output_keys) + result.index = index if self.axis == 1: result = result.T diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index afb22a732691c..408c8beadb509 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -179,7 +179,13 @@ def test_arg_passthru(): tm.assert_index_equal(result.columns, expected_columns_numeric) result = f(numeric_only=False) - tm.assert_frame_equal(result.reindex_like(expected), expected) + + # TODO: median isn't implemented for DTI but was working blockwise before? + if attr == "median": + new_expected = expected.drop(columns=["datetime", "datetimetz"]) + tm.assert_frame_equal(result, new_expected) + else: + tm.assert_frame_equal(result.reindex_like(expected), expected) # TODO: min, max *should* handle # categorical (ordered) dtype From 2324c0cc08046a2f4298d7ed8ed236ac9bd6e82a Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 11:01:45 -0700 Subject: [PATCH 05/12] Fixed issue with duplicate column names in count --- pandas/core/groupby/generic.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 89c0abad1eb0c..f801890ac8d04 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1454,20 +1454,23 @@ def _insert_inaxis_grouper_inplace(self, result): result.insert(0, name, lev) def _wrap_aggregated_output(self, output, names=None): + index = self.grouper.result_index + if isinstance(output, dict): - result = DataFrame(output) + result = DataFrame(output, index=index) else: agg_axis = 0 if self.axis == 1 else 1 agg_labels = self._obj_with_exclusions._get_axis(agg_axis) - output_keys = self._decide_output_index(output, agg_labels) + output_keys = self._decide_output_index(output, index=index, columns=agg_labels) result = DataFrame(output, columns=output_keys) + if names: + result.columns = names + if not self.as_index: self._insert_inaxis_grouper_inplace(result) result = result._consolidate() - else: - index = self.grouper.result_index - result.index = index + result = result.reset_index(drop=True) if self.axis == 1: result = result.T @@ -1505,18 +1508,20 @@ def count(self): Count of values within each group. """ output = OrderedDict() + names = [] # TODO: dispatch to _cython_agg_general instead of custom looping # TODO: refactor with series logic ids, _, ngroups = self.grouper.group_info - for name, obj in self._iterate_slices(): + for index, (name, obj) in enumerate(self._obj_with_exclusions.items()): mask = (ids != -1) & ~isna(obj) ids = ensure_platform_int(ids) minlength = ngroups or 0 out = np.bincount(ids[mask], minlength=minlength) - output[name] = out + output[index] = out + names.append(name) - return self._wrap_aggregated_output(output) + return self._wrap_aggregated_output(output, names=names) def nunique(self, dropna=True): """ From 5de118f8751b6a7cc710ee7f704c292e9f08e20b Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 11:08:52 -0700 Subject: [PATCH 06/12] Fixed issue with count over axis=1 --- pandas/core/groupby/generic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index f801890ac8d04..276cb1cd3ba47 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1513,7 +1513,13 @@ def count(self): # TODO: dispatch to _cython_agg_general instead of custom looping # TODO: refactor with series logic ids, _, ngroups = self.grouper.group_info - for index, (name, obj) in enumerate(self._obj_with_exclusions.items()): + + if self.axis == 0: + iter_obj = self._obj_with_exclusions + else: + iter_obj = self._obj_with_exclusions.T + + for index, (name, obj) in enumerate(iter_obj.items()): mask = (ids != -1) & ~isna(obj) ids = ensure_platform_int(ids) minlength = ngroups or 0 From e6b5fd1e140d7b53ae4283d1cc1e702670e9a952 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 11:21:45 -0700 Subject: [PATCH 07/12] Fixed issue with transform output --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e93ce3ce93164..e7f29e99af44f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -692,7 +692,7 @@ def __iter__(self): Generator yielding sequence of (name, subsetted object) for each group """ - return self.grouper.get_iterator(self.obj, axis=self.axis) + return self.grouper.get_iterator(self._selected_obj, axis=self.axis) @Appender( _apply_docs["template"].format( From 44a0c6ab472caa8dfd0204bb75aea7d8d5d37893 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 11:42:17 -0700 Subject: [PATCH 08/12] Blackify --- pandas/core/groupby/generic.py | 5 +++-- pandas/tests/groupby/test_function.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 276cb1cd3ba47..cb9d6e602044e 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1297,7 +1297,6 @@ class DataFrameGroupBy(NDFrameGroupBy): _apply_whitelist = base.dataframe_apply_whitelist - _agg_see_also_doc = dedent( """ See Also @@ -1461,7 +1460,9 @@ def _wrap_aggregated_output(self, output, names=None): else: agg_axis = 0 if self.axis == 1 else 1 agg_labels = self._obj_with_exclusions._get_axis(agg_axis) - output_keys = self._decide_output_index(output, index=index, columns=agg_labels) + output_keys = self._decide_output_index( + output, index=index, columns=agg_labels + ) result = DataFrame(output, columns=output_keys) if names: diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 408c8beadb509..ce65eb20eb853 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -180,7 +180,7 @@ def test_arg_passthru(): result = f(numeric_only=False) - # TODO: median isn't implemented for DTI but was working blockwise before? + # TODO: median isn't implemented for DTI but was working blockwise before? if attr == "median": new_expected = expected.drop(columns=["datetime", "datetimetz"]) tm.assert_frame_equal(result, new_expected) From 9efdab656064d2a7c13147dfabcd69bc64e9d805 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 18:07:20 -0700 Subject: [PATCH 09/12] Fixed issue with index names not carrying over --- pandas/core/groupby/generic.py | 78 +++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index cb9d6e602044e..cdab754a9d043 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -11,7 +11,19 @@ from functools import partial from textwrap import dedent import typing -from typing import Any, Callable, FrozenSet, Sequence, Type, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Hashable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) import warnings import numpy as np @@ -267,18 +279,39 @@ def _aggregate_item_by_item(self, func, *args, **kwargs): return DataFrame(result, columns=result_columns) - def _decide_output_index(self, output, labels): - if len(output) == len(labels): - output_keys = labels + def _decide_output_index( + self, + output: Dict[int, np.ndarray], + labels: Index, + col_labels: Optional[List[Union[Hashable, Tuple[Hashable, ...]]]] = None, + ) -> Index: + """ + Determine axis labels to use while wrapping aggregated values. + + Parameters + ---------- + output : dict of ndarrays + Results of aggregating by-column. Column names should be integer position + labels : Index + Existing labels of selected object. Used to determine resulting shape and name(s) + col_labels : list, optional + The ultimate column labels for the reshaped object. Each entry in this list + should correspond to a key value in output. Must be valid column labels and tuples + are contained within should map to a MultiIndex + + Returns + ------- + Index or MultiIndex + """ + if col_labels: + keys = col_labels else: - output_keys = sorted(output) - try: - output_keys.sort() - except TypeError: - pass + keys = output.keys() - if isinstance(labels, MultiIndex): - output_keys = MultiIndex.from_tuples(output_keys, names=labels.names) + if isinstance(labels, Index): + output_keys = Index(keys, name=labels.name) + elif isinstance(labels, MultiIndex): + output_keys = MultiIndex.from_tuples(keys, names=labels.names) return output_keys @@ -1452,21 +1485,18 @@ def _insert_inaxis_grouper_inplace(self, result): if in_axis: result.insert(0, name, lev) - def _wrap_aggregated_output(self, output, names=None): + def _wrap_aggregated_output( + self, + output: Dict, + names: Optional[List[Union[Hashable, Tuple[Hashable, ...]]]] = None, + ) -> DataFrame: index = self.grouper.result_index + result = DataFrame(output, index) - if isinstance(output, dict): - result = DataFrame(output, index=index) - else: - agg_axis = 0 if self.axis == 1 else 1 - agg_labels = self._obj_with_exclusions._get_axis(agg_axis) - output_keys = self._decide_output_index( - output, index=index, columns=agg_labels - ) - result = DataFrame(output, columns=output_keys) - - if names: - result.columns = names + agg_axis = 0 if self.axis == 1 else 1 + agg_labels = self._obj_with_exclusions._get_axis(agg_axis) + output_keys = self._decide_output_index(output, agg_labels, names) + result.columns = output_keys if not self.as_index: self._insert_inaxis_grouper_inplace(result) From 654ccf220230eb47d69c5c29b8f558ee0f4281e5 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 18:09:09 -0700 Subject: [PATCH 10/12] flake8 fixup --- pandas/core/groupby/generic.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index cdab754a9d043..2ecb27b82423a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -28,14 +28,13 @@ import numpy as np -from pandas._libs import Timestamp, lib +from pandas._libs import Timestamp from pandas.compat import PY36 from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, Substitution from pandas.core.dtypes.cast import ( maybe_convert_objects, - maybe_downcast_numeric, maybe_downcast_to_dtype, ) from pandas.core.dtypes.common import ( @@ -51,11 +50,11 @@ is_object_dtype, is_scalar, ) -from pandas.core.dtypes.missing import _isna_ndarraylike, isna, notna +from pandas.core.dtypes.missing import isna, notna from pandas._typing import FrameOrSeries import pandas.core.algorithms as algorithms -from pandas.core.base import DataError, SpecificationError +from pandas.core.base import SpecificationError import pandas.core.common as com from pandas.core.frame import DataFrame from pandas.core.generic import ABCDataFrame, ABCSeries, NDFrame, _shared_docs @@ -64,7 +63,6 @@ GroupBy, _apply_docs, _transform_template, - groupby, ) from pandas.core.index import Index, MultiIndex, _all_indexes_same import pandas.core.indexes.base as ibase @@ -291,13 +289,16 @@ def _decide_output_index( Parameters ---------- output : dict of ndarrays - Results of aggregating by-column. Column names should be integer position + Results of aggregating by-column. Column names should be integer + position. labels : Index - Existing labels of selected object. Used to determine resulting shape and name(s) + Existing labels of selected object. Used to determine resulting + shape and name(s). col_labels : list, optional - The ultimate column labels for the reshaped object. Each entry in this list - should correspond to a key value in output. Must be valid column labels and tuples - are contained within should map to a MultiIndex + The ultimate column labels for the reshaped object. Each entry in + this list should correspond to a key value in output. Must be valid + column labels and tuples are contained within should map to a + MultiIndex. Returns ------- From 2d3c5ddb57eb6527add3ef00a18337f23ae936da Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 18:12:21 -0700 Subject: [PATCH 11/12] typing fixup --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 2ecb27b82423a..490334287816f 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -307,7 +307,7 @@ def _decide_output_index( if col_labels: keys = col_labels else: - keys = output.keys() + keys = list(output.keys()) if isinstance(labels, Index): output_keys = Index(keys, name=labels.name) From 6c521f201e5bf3f5d3d0fa598553082c19d8014b Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 7 Oct 2019 18:24:47 -0700 Subject: [PATCH 12/12] Fixed issue with multiindex names --- pandas/core/groupby/generic.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 490334287816f..7ff15ec2aae9a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -33,10 +33,7 @@ from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, Substitution -from pandas.core.dtypes.cast import ( - maybe_convert_objects, - maybe_downcast_to_dtype, -) +from pandas.core.dtypes.cast import maybe_convert_objects, maybe_downcast_to_dtype from pandas.core.dtypes.common import ( ensure_int64, ensure_platform_int, @@ -59,11 +56,7 @@ from pandas.core.frame import DataFrame from pandas.core.generic import ABCDataFrame, ABCSeries, NDFrame, _shared_docs from pandas.core.groupby import base -from pandas.core.groupby.groupby import ( - GroupBy, - _apply_docs, - _transform_template, -) +from pandas.core.groupby.groupby import GroupBy, _apply_docs, _transform_template from pandas.core.index import Index, MultiIndex, _all_indexes_same import pandas.core.indexes.base as ibase from pandas.core.series import Series @@ -279,7 +272,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs): def _decide_output_index( self, - output: Dict[int, np.ndarray], + output: Dict, labels: Index, col_labels: Optional[List[Union[Hashable, Tuple[Hashable, ...]]]] = None, ) -> Index: @@ -289,8 +282,7 @@ def _decide_output_index( Parameters ---------- output : dict of ndarrays - Results of aggregating by-column. Column names should be integer - position. + Results of aggregating by-column. labels : Index Existing labels of selected object. Used to determine resulting shape and name(s). @@ -303,16 +295,23 @@ def _decide_output_index( Returns ------- Index or MultiIndex + + Notes + ----- + Ideally output should always have integers as a key and the col_labels + should be provided separately, but as of writing this is not the case. + When output is not using integers there is a risk of duplicate column + labels not be handled correctly. """ if col_labels: keys = col_labels else: keys = list(output.keys()) - if isinstance(labels, Index): - output_keys = Index(keys, name=labels.name) - elif isinstance(labels, MultiIndex): + if isinstance(labels, MultiIndex): output_keys = MultiIndex.from_tuples(keys, names=labels.names) + else: + output_keys = Index(keys, name=labels.name) return output_keys @@ -1488,7 +1487,7 @@ def _insert_inaxis_grouper_inplace(self, result): def _wrap_aggregated_output( self, - output: Dict, + output: Dict[int, np.ndarray], names: Optional[List[Union[Hashable, Tuple[Hashable, ...]]]] = None, ) -> DataFrame: index = self.grouper.result_index