Skip to content

Commit 45aa0ce

Browse files
WillAyd authored and jreback committed
Added cast blacklist for certain transform agg funcs (#19355)
1 parent 9872d67 commit 45aa0ce

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ Groupby/Resample/Rolling
506506
- Fixed regression in :func:`DataFrame.groupby` which would not emit an error when called with a tuple key not in the index (:issue:`18798`)
507507
- Bug in :func:`DataFrame.resample` which silently ignored unsupported (or mistyped) options for ``label``, ``closed`` and ``convention`` (:issue:`19303`)
508508
- Bug in :func:`DataFrame.groupby` where tuples were interpreted as lists of keys rather than as keys (:issue:`17979`, :issue:`18249`)
509+
- Bug in ``transform`` where particular aggregation functions were being incorrectly cast to match the dtype(s) of the grouped data (:issue:`19200`)
509510
-
510511

511512
Sparse

pandas/core/groupby.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@
345345
_cython_transforms = frozenset(['cumprod', 'cumsum', 'shift',
346346
'cummin', 'cummax'])
347347

348+
# Names of aggregation functions whose ``transform`` result should not be
# cast back to the dtype(s) of the grouped data.
_cython_cast_blacklist = frozenset(['rank', 'count', 'size'])
349+
348350

349351
class Grouper(object):
350352
"""
@@ -965,6 +967,21 @@ def _try_cast(self, result, obj, numeric_only=False):
965967

966968
return result
967969

970+
def _transform_should_cast(self, func_nm):
    """
    Decide whether ``transform`` should try to cast an aggregation result.

    Parameters
    ----------
    func_nm : str
        The name of the aggregation function being performed

    Returns
    -------
    bool
        Whether transform should attempt to cast the result of aggregation
    """
    # Check the blacklist first: the set lookup is cheap and lets us skip
    # computing the group sizes for functions that must never be cast.
    if func_nm in _cython_cast_blacklist:
        return False

    # Only attempt a cast when at least one group has observations.
    return bool((self.size().fillna(0) > 0).any())
984+
968985
def _cython_transform(self, how, numeric_only=True):
969986
output = collections.OrderedDict()
970987
for name, obj in self._iterate_slices():
@@ -3333,7 +3350,7 @@ def transform(self, func, *args, **kwargs):
33333350
else:
33343351
# cythonized aggregation and merge
33353352
return self._transform_fast(
3336-
lambda: getattr(self, func)(*args, **kwargs))
3353+
lambda: getattr(self, func)(*args, **kwargs), func)
33373354

33383355
# reg transform
33393356
klass = self._selected_obj.__class__
@@ -3364,7 +3381,7 @@ def transform(self, func, *args, **kwargs):
33643381
result.index = self._selected_obj.index
33653382
return result
33663383

3367-
def _transform_fast(self, func):
3384+
def _transform_fast(self, func, func_nm):
33683385
"""
33693386
fast version of transform, only applicable to
33703387
builtin/cythonizable functions
@@ -3373,7 +3390,7 @@ def _transform_fast(self, func):
33733390
func = getattr(self, func)
33743391

33753392
ids, _, ngroup = self.grouper.group_info
3376-
cast = (self.size().fillna(0) > 0).any()
3393+
cast = self._transform_should_cast(func_nm)
33773394
out = algorithms.take_1d(func().values, ids)
33783395
if cast:
33793396
out = self._try_cast(out, self.obj)
@@ -4127,15 +4144,15 @@ def transform(self, func, *args, **kwargs):
41274144
if not result.columns.equals(obj.columns):
41284145
return self._transform_general(func, *args, **kwargs)
41294146

4130-
return self._transform_fast(result, obj)
4147+
return self._transform_fast(result, obj, func)
41314148

4132-
def _transform_fast(self, result, obj):
4149+
def _transform_fast(self, result, obj, func_nm):
41334150
"""
41344151
Fast transform path for aggregations
41354152
"""
41364153
# if there were groups with no observations (Categorical only?)
41374154
# try casting data to original dtype
4138-
cast = (self.size().fillna(0) > 0).any()
4155+
cast = self._transform_should_cast(func_nm)
41394156

41404157
# for each col, reshape to size of original frame
41414158
# by take operation

pandas/tests/groupby/test_transform.py

+25
Original file line numberDiff line numberDiff line change
@@ -582,3 +582,28 @@ def test_transform_with_non_scalar_group(self):
582582
'group.*',
583583
df.groupby(axis=1, level=1).transform,
584584
lambda z: z.div(z.sum(axis=1), axis=0))
585+
586+
@pytest.mark.parametrize('cols,exp,comp_func', [
    ('a', pd.Series([1, 1, 1], name='a'), tm.assert_series_equal),
    (['a', 'c'], pd.DataFrame({'a': [1, 1, 1], 'c': [1, 1, 1]}),
     tm.assert_frame_equal)
])
@pytest.mark.parametrize('agg_func', [
    'count', 'rank', 'size'])
def test_transform_numeric_ret(self, cols, exp, comp_func, agg_func):
    """
    Check that ``transform`` with 'count', 'rank' or 'size' returns numeric
    results instead of casting them to the dtype of the grouped data
    (column 'a' holds datetimes here).
    """
    if agg_func == 'size' and isinstance(cols, list):
        pytest.xfail("'size' transformation not supported with "
                     "NDFrameGroupBy")

    # GH 19200
    df = pd.DataFrame(
        {'a': pd.date_range('2018-01-01', periods=3),
         'b': range(3),
         'c': range(7, 10)})

    # Every group holds exactly one row, so each aggregation yields 1.
    result = df.groupby('b')[cols].transform(agg_func)

    # 'rank' produces floats, so align the expected dtype before comparing.
    if agg_func == 'rank':
        exp = exp.astype('float')

    comp_func(result, exp)

0 commit comments

Comments
 (0)