Skip to content

Commit 045d0c7

Browse files
committed
add back some casting
1 parent b66a1c8 commit 045d0c7

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

pandas/core/groupby.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -2776,8 +2776,11 @@ def _transform_fast(self, func):
27762776
func = getattr(self, func)
27772777

27782778
ids, _, ngroup = self.grouper.group_info
2779-
2779+
counts = self.size().fillna(0).values
2780+
cast = (counts == 0).any()
27802781
out = algos.take_1d(func().values, ids)
2782+
if cast:
2783+
out = self._try_cast(out, self.obj)
27812784
return Series(out, index=self.obj.index, name=self.obj.name)
27822785

27832786
def filter(self, func, dropna=True, *args, **kwargs): # noqa
@@ -3456,11 +3459,21 @@ def transform(self, func, *args, **kwargs):
34563459
if not result.columns.equals(obj.columns):
34573460
return self._transform_general(func, *args, **kwargs)
34583461

3459-
# Fast transform
3462+
# Fast transform path for aggregations
3463+
3464+
# if there were groups with no observations (Categorical only?)
3465+
# try casting data to original dtype
3466+
counts = self.size().fillna(0).values
3467+
cast = (counts == 0).any()
3468+
3469+
# by column (could be by block?) reshape aggregated data to
3470+
# size of original frame by repeating observations with take
34603471
ids, _, ngroup = self.grouper.group_info
34613472
out = {}
34623473
for col in result:
34633474
out[col] = algos.take_nd(result[col].values, ids)
3475+
if cast:
3476+
out[col] = self._try_cast(out[col], obj[col])
34643477

34653478
return DataFrame(out, columns=result.columns, index=obj.index)
34663479

pandas/tests/test_categorical.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3043,8 +3043,7 @@ def f(x):
30433043
c = pd.cut(df.a, bins=[-10, 0, 10, 20, 30, 40])
30443044

30453045
result = df.a.groupby(c).transform(sum)
3046-
tm.assert_series_equal(result, df['a'], check_names=False)
3047-
self.assertTrue(result.name is None)
3046+
tm.assert_series_equal(result, df['a'])
30483047

30493048
tm.assert_series_equal(
30503049
df.a.groupby(c).transform(lambda xs: np.sum(xs)), df['a'])

0 commit comments

Comments
 (0)