Skip to content

Commit 537e170

Browse files
committed
Revert "REF: remove no-op casting (pandas-dev#41317)"
This reverts commit e0cc505.
1 parent c069ce8 commit 537e170

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

pandas/core/groupby/groupby.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ class providing the base-class of operations.
6060
doc,
6161
)
6262

63+
from pandas.core.dtypes.cast import maybe_downcast_numeric
6364
from pandas.core.dtypes.common import (
65+
ensure_float,
6466
is_bool_dtype,
6567
is_datetime64_dtype,
6668
is_integer_dtype,
@@ -1273,7 +1275,19 @@ def _python_agg_general(self, func, *args, **kwargs):
12731275
except TypeError:
12741276
continue
12751277

1278+
assert result is not None
12761279
key = base.OutputKey(label=name, position=idx)
1280+
1281+
if self.grouper._filter_empty_groups:
1282+
mask = counts.ravel() > 0
1283+
1284+
# since we are masking, make sure that we have a float object
1285+
values = result
1286+
if is_numeric_dtype(values.dtype):
1287+
values = ensure_float(values)
1288+
1289+
result = maybe_downcast_numeric(values[mask], result.dtype)
1290+
12771291
output[key] = result
12781292

12791293
if not output:
@@ -3072,7 +3086,9 @@ def _reindex_output(
30723086
Object (potentially) re-indexed to include all possible groups.
30733087
"""
30743088
groupings = self.grouper.groupings
3075-
if len(groupings) == 1:
3089+
if groupings is None:
3090+
return output
3091+
elif len(groupings) == 1:
30763092
return output
30773093

30783094
# if we only care about the observed values

pandas/core/groupby/ops.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
needs_i8_conversion,
5757
)
5858
from pandas.core.dtypes.dtypes import ExtensionDtype
59+
from pandas.core.dtypes.generic import ABCCategoricalIndex
5960
from pandas.core.dtypes.missing import (
6061
isna,
6162
maybe_fill,
@@ -88,7 +89,6 @@
8889
grouper,
8990
)
9091
from pandas.core.indexes.api import (
91-
CategoricalIndex,
9292
Index,
9393
MultiIndex,
9494
ensure_index,
@@ -676,6 +676,7 @@ def __init__(
676676
):
677677
assert isinstance(axis, Index), axis
678678

679+
self._filter_empty_groups = self.compressed = len(groupings) != 1
679680
self.axis = axis
680681
self._groupings: list[grouper.Grouping] = list(groupings)
681682
self.sort = sort
@@ -821,7 +822,9 @@ def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
821822
@cache_readonly
822823
def indices(self):
823824
""" dict {group name -> group indices} """
824-
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
825+
if len(self.groupings) == 1 and isinstance(
826+
self.result_index, ABCCategoricalIndex
827+
):
825828
# This shows unused categories in indices GH#38642
826829
return self.groupings[0].indices
827830
codes_list = [ping.codes for ping in self.groupings]
@@ -910,7 +913,7 @@ def reconstructed_codes(self) -> list[np.ndarray]:
910913

911914
@cache_readonly
912915
def result_index(self) -> Index:
913-
if len(self.groupings) == 1:
916+
if not self.compressed and len(self.groupings) == 1:
914917
return self.groupings[0].result_index.rename(self.names[0])
915918

916919
codes = self.reconstructed_codes
@@ -921,9 +924,7 @@ def result_index(self) -> Index:
921924

922925
@final
923926
def get_group_levels(self) -> list[Index]:
924-
# Note: only called from _insert_inaxis_grouper_inplace, which
925-
# is only called for BaseGrouper, never for BinGrouper
926-
if len(self.groupings) == 1:
927+
if not self.compressed and len(self.groupings) == 1:
927928
return [self.groupings[0].result_index]
928929

929930
name_list = []
@@ -1090,6 +1091,7 @@ def __init__(
10901091
):
10911092
self.bins = ensure_int64(bins)
10921093
self.binlabels = ensure_index(binlabels)
1094+
self._filter_empty_groups = False
10931095
self.mutated = mutated
10941096
self.indexer = indexer
10951097

@@ -1199,9 +1201,10 @@ def names(self) -> list[Hashable]:
11991201

12001202
@property
12011203
def groupings(self) -> list[grouper.Grouping]:
1202-
lev = self.binlabels
1203-
ping = grouper.Grouping(lev, lev, in_axis=False, level=None, name=lev.name)
1204-
return [ping]
1204+
return [
1205+
grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
1206+
for lvl, name in zip(self.levels, self.names)
1207+
]
12051208

12061209
def _aggregate_series_fast(self, obj: Series, func: F) -> np.ndarray:
12071210
# -> np.ndarray[object]

0 commit comments

Comments
 (0)