Skip to content

Commit 045bcb5

Browse files
committed
BUG: make sure that we are passing thru kwargs to groupby
BUG: allow timedelta64 to work in groupby with numeric_only=False closes #5724
1 parent 6eb705f commit 045bcb5

File tree

4 files changed

+171
-27
lines changed

4 files changed

+171
-27
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ Bug Fixes
323323

324324

325325

326+
- Bug in groupby operations with timedelta64 when passing ``numeric_only=False`` (:issue:`5724`)
326327

327328

328329
- Bug in ``DataFrame.to_html`` with ``index=False`` and ``max_rows`` raising ``IndexError`` (:issue:`14998`)

pandas/compat/numpy/function.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,18 @@ def validate_expanding_func(name, args, kwargs):
306306
raise UnsupportedFunctionCall(msg)
307307

308308

309-
def validate_groupby_func(name, args, kwargs):
309+
def validate_groupby_func(name, args, kwargs, allowed=None):
310310
"""
311-
'args' and 'kwargs' should be empty because all of
311+
'args' and 'kwargs' should be empty, except for allowed
312+
kwargs because all of
312313
their necessary parameters are explicitly listed in
313314
the function signature
314315
"""
316+
if allowed is None:
317+
allowed = []
318+
319+
kwargs = set(kwargs) - set(allowed)
320+
315321
if len(args) + len(kwargs) > 0:
316322
raise UnsupportedFunctionCall((
317323
"numpy operations are not valid "

pandas/core/groupby.py

+81-25
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_categorical_dtype,
2020
is_datetimelike,
2121
is_datetime_or_timedelta_dtype,
22+
is_datetime64_any_dtype,
2223
is_bool, is_integer_dtype,
2324
is_complex_dtype,
2425
is_bool_dtype,
@@ -108,10 +109,12 @@ def _groupby_function(name, alias, npfunc, numeric_only=True,
108109
@Substitution(name='groupby', f=name)
109110
@Appender(_doc_template)
110111
@Appender(_local_template)
111-
def f(self):
112+
def f(self, **kwargs):
113+
if 'numeric_only' not in kwargs:
114+
kwargs['numeric_only'] = numeric_only
112115
self._set_group_selection()
113116
try:
114-
return self._cython_agg_general(alias, numeric_only=numeric_only)
117+
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
115118
except AssertionError as e:
116119
raise SpecificationError(str(e))
117120
except Exception:
@@ -126,7 +129,9 @@ def f(self):
126129

127130

128131
def _first_compat(x, axis=0):
132+
129133
def _first(x):
134+
130135
x = np.asarray(x)
131136
x = x[notnull(x)]
132137
if len(x) == 0:
@@ -141,6 +146,7 @@ def _first(x):
141146

142147
def _last_compat(x, axis=0):
143148
def _last(x):
149+
144150
x = np.asarray(x)
145151
x = x[notnull(x)]
146152
if len(x) == 0:
@@ -782,6 +788,8 @@ def _cython_transform(self, how, numeric_only=True):
782788

783789
try:
784790
result, names = self.grouper.transform(obj.values, how)
791+
except NotImplementedError:
792+
continue
785793
except AssertionError as e:
786794
raise GroupByError(str(e))
787795
output[name] = self._try_cast(result, obj)
@@ -791,7 +799,7 @@ def _cython_transform(self, how, numeric_only=True):
791799

792800
return self._wrap_transformed_output(output, names)
793801

794-
def _cython_agg_general(self, how, numeric_only=True):
802+
def _cython_agg_general(self, how, alt=None, numeric_only=True):
795803
output = {}
796804
for name, obj in self._iterate_slices():
797805
is_numeric = is_numeric_dtype(obj.dtype)
@@ -1014,26 +1022,26 @@ def mean(self, *args, **kwargs):
10141022
10151023
For multiple groupings, the result index will be a MultiIndex
10161024
"""
1017-
nv.validate_groupby_func('mean', args, kwargs)
1025+
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
10181026
try:
1019-
return self._cython_agg_general('mean')
1027+
return self._cython_agg_general('mean', **kwargs)
10201028
except GroupByError:
10211029
raise
10221030
except Exception: # pragma: no cover
10231031
self._set_group_selection()
1024-
f = lambda x: x.mean(axis=self.axis)
1032+
f = lambda x: x.mean(axis=self.axis, **kwargs)
10251033
return self._python_agg_general(f)
10261034

10271035
@Substitution(name='groupby')
10281036
@Appender(_doc_template)
1029-
def median(self):
1037+
def median(self, **kwargs):
10301038
"""
10311039
Compute median of groups, excluding missing values
10321040
10331041
For multiple groupings, the result index will be a MultiIndex
10341042
"""
10351043
try:
1036-
return self._cython_agg_general('median')
1044+
return self._cython_agg_general('median', **kwargs)
10371045
except GroupByError:
10381046
raise
10391047
except Exception: # pragma: no cover
@@ -1043,7 +1051,7 @@ def median(self):
10431051
def f(x):
10441052
if isinstance(x, np.ndarray):
10451053
x = Series(x)
1046-
return x.median(axis=self.axis)
1054+
return x.median(axis=self.axis, **kwargs)
10471055
return self._python_agg_general(f)
10481056

10491057
@Substitution(name='groupby')
@@ -1062,7 +1070,7 @@ def std(self, ddof=1, *args, **kwargs):
10621070

10631071
# TODO: implement at Cython level?
10641072
nv.validate_groupby_func('std', args, kwargs)
1065-
return np.sqrt(self.var(ddof=ddof))
1073+
return np.sqrt(self.var(ddof=ddof, **kwargs))
10661074

10671075
@Substitution(name='groupby')
10681076
@Appender(_doc_template)
@@ -1079,10 +1087,10 @@ def var(self, ddof=1, *args, **kwargs):
10791087
"""
10801088
nv.validate_groupby_func('var', args, kwargs)
10811089
if ddof == 1:
1082-
return self._cython_agg_general('var')
1090+
return self._cython_agg_general('var', **kwargs)
10831091
else:
10841092
self._set_group_selection()
1085-
f = lambda x: x.var(ddof=ddof)
1093+
f = lambda x: x.var(ddof=ddof, **kwargs)
10861094
return self._python_agg_general(f)
10871095

10881096
@Substitution(name='groupby')
@@ -1399,21 +1407,21 @@ def cumcount(self, ascending=True):
13991407
@Appender(_doc_template)
14001408
def cumprod(self, axis=0, *args, **kwargs):
14011409
"""Cumulative product for each group"""
1402-
nv.validate_groupby_func('cumprod', args, kwargs)
1410+
nv.validate_groupby_func('cumprod', args, kwargs, ['numeric_only'])
14031411
if axis != 0:
1404-
return self.apply(lambda x: x.cumprod(axis=axis))
1412+
return self.apply(lambda x: x.cumprod(axis=axis, **kwargs))
14051413

1406-
return self._cython_transform('cumprod')
1414+
return self._cython_transform('cumprod', **kwargs)
14071415

14081416
@Substitution(name='groupby')
14091417
@Appender(_doc_template)
14101418
def cumsum(self, axis=0, *args, **kwargs):
14111419
"""Cumulative sum for each group"""
1412-
nv.validate_groupby_func('cumsum', args, kwargs)
1420+
nv.validate_groupby_func('cumsum', args, kwargs, ['numeric_only'])
14131421
if axis != 0:
1414-
return self.apply(lambda x: x.cumsum(axis=axis))
1422+
return self.apply(lambda x: x.cumsum(axis=axis, **kwargs))
14151423

1416-
return self._cython_transform('cumsum')
1424+
return self._cython_transform('cumsum', **kwargs)
14171425

14181426
@Substitution(name='groupby')
14191427
@Appender(_doc_template)
@@ -1807,6 +1815,28 @@ def wrapper(*args, **kwargs):
18071815
def _cython_operation(self, kind, values, how, axis):
18081816
assert kind in ['transform', 'aggregate']
18091817

1818+
# can we do this operation with our cython functions
1819+
# if not raise NotImplementedError
1820+
1821+
# we raise NotImplemented if this is an invalid operation
1822+
# entirely, e.g. adding datetimes
1823+
1824+
# categoricals are only 1d, so we
1825+
# are not setup for dim transforming
1826+
if is_categorical_dtype(values):
1827+
raise NotImplementedError(
1828+
"categoricals are not support in cython ops ATM")
1829+
elif is_datetime64_any_dtype(values):
1830+
if how in ['add', 'prod', 'cumsum', 'cumprod']:
1831+
raise NotImplementedError(
1832+
"datetime64 type does not support {} "
1833+
"operations".format(how))
1834+
elif is_timedelta64_dtype(values):
1835+
if how in ['prod', 'cumprod']:
1836+
raise NotImplementedError(
1837+
"timedelta64 type does not support {} "
1838+
"operations".format(how))
1839+
18101840
arity = self._cython_arity.get(how, 1)
18111841

18121842
vdim = values.ndim
@@ -3134,9 +3164,9 @@ def _iterate_slices(self):
31343164
continue
31353165
yield val, slicer(val)
31363166

3137-
def _cython_agg_general(self, how, numeric_only=True):
3167+
def _cython_agg_general(self, how, alt=None, numeric_only=True):
31383168
new_items, new_blocks = self._cython_agg_blocks(
3139-
how, numeric_only=numeric_only)
3169+
how, alt=alt, numeric_only=numeric_only)
31403170
return self._wrap_agged_blocks(new_items, new_blocks)
31413171

31423172
def _wrap_agged_blocks(self, items, blocks):
@@ -3162,29 +3192,55 @@ def _wrap_agged_blocks(self, items, blocks):
31623192

31633193
_block_agg_axis = 0
31643194

3165-
def _cython_agg_blocks(self, how, numeric_only=True):
3195+
def _cython_agg_blocks(self, how, alt=None, numeric_only=True):
3196+
# TODO: the actual managing of mgr_locs is a PITA
3197+
# here, it should happen via BlockManager.combine
3198+
31663199
data, agg_axis = self._get_data_to_aggregate()
31673200

31683201
new_blocks = []
31693202

31703203
if numeric_only:
31713204
data = data.get_numeric_data(copy=False)
31723205

3206+
offset = 0
3207+
new_items = []
31733208
for block in data.blocks:
31743209

3175-
result, _ = self.grouper.aggregate(
3176-
block.values, how, axis=agg_axis)
3210+
locs = block.mgr_locs.as_array
3211+
try:
3212+
result, _ = self.grouper.aggregate(
3213+
block.values, how, axis=agg_axis)
3214+
except NotImplementedError:
3215+
# generally if we have numeric_only=False
3216+
# and non-applicable functions
3217+
# try to python agg
3218+
3219+
if alt is None:
3220+
# we cannot perform the operation
3221+
# in an alternate way, exclude the block
3222+
continue
3223+
3224+
# call our grouper again with only this block
3225+
obj = self.obj.iloc[:, locs]
3226+
s = groupby(obj, self.grouper)
3227+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
3228+
result = result._data.blocks[0]
31773229

31783230
# see if we can cast the block back to the original dtype
31793231
result = block._try_coerce_and_cast_result(result)
31803232

3181-
newb = make_block(result, placement=block.mgr_locs)
3233+
new_items.append(locs)
3234+
newb = block.make_block_same_class(
3235+
result,
3236+
placement=np.arange(offset, offset + len(locs)))
3237+
offset += len(locs)
31823238
new_blocks.append(newb)
31833239

31843240
if len(new_blocks) == 0:
31853241
raise DataError('No numeric types to aggregate')
31863242

3187-
return data.items, new_blocks
3243+
return data.items.take(np.concatenate(new_items)), new_blocks
31883244

31893245
def _get_data_to_aggregate(self):
31903246
obj = self._obj_with_exclusions

pandas/tests/groupby/test_groupby.py

+81
Original file line numberDiff line numberDiff line change
@@ -2260,6 +2260,86 @@ def test_max_min_non_numeric(self):
22602260
result = aa.groupby('nn').min()
22612261
self.assertTrue('ss' in result)
22622262

2263+
def test_arg_passthru(self):
2264+
# make sure that we are passing thru kwargs
2265+
# to our agg functions
2266+
2267+
# GH3668
2268+
# GH5724
2269+
df = pd.DataFrame({
2270+
'group': [1, 1, 2],
2271+
'int': [1, 2, 3],
2272+
'float': [1., 2., 3.],
2273+
'string': list('abc'),
2274+
'category': pd.Series(list('abc')).astype('category'),
2275+
'datetime': pd.date_range('20130101', periods=3),
2276+
'datetimetz': pd.date_range('20130101',
2277+
periods=3,
2278+
tz='US/Eastern'),
2279+
'timedelta': pd.timedelta_range('1 s', periods=3, freq='s')})
2280+
2281+
# basic
2282+
result = df.groupby('group').mean()
2283+
expected = pd.DataFrame(
2284+
{'int': [1.5, 3],
2285+
'float': [1.5, 3.]},
2286+
index=Index([1, 2], name='group'))
2287+
assert_frame_equal(result.reindex_like(expected), expected)
2288+
2289+
# mean / median
2290+
expected = pd.DataFrame(
2291+
{'int': [1.5, 3],
2292+
'float': [1.5, 3.],
2293+
'timedelta': [pd.Timedelta('1.5s'),
2294+
pd.Timedelta('3s')],
2295+
'datetime': [pd.Timestamp('2013-01-01 12:00:00'),
2296+
pd.Timestamp('2013-01-03 00:00:00')],
2297+
'datetimetz': [
2298+
pd.Timestamp('2013-01-01 12:00:00', tz='US/Eastern'),
2299+
pd.Timestamp('2013-01-03 00:00:00', tz='US/Eastern')]},
2300+
index=Index([1, 2], name='group'))
2301+
for attr in ['mean', 'median']:
2302+
f = getattr(df.groupby('group'), attr)
2303+
result = f(numeric_only=False)
2304+
assert_frame_equal(result, expected)
2305+
2306+
expected_columns = Index(['datetime', 'datetimetz',
2307+
'float', 'int',
2308+
'string', 'timedelta'])
2309+
2310+
# TODO: min, max *should*
2311+
# categorical (ordered) dtype
2312+
for attr in ['min', 'max']:
2313+
f = getattr(df.groupby('group'), attr)
2314+
result = f(numeric_only=False)
2315+
tm.assert_index_equal(result.columns, expected_columns)
2316+
2317+
expected_columns = Index(['category', 'datetime', 'datetimetz',
2318+
'float', 'int',
2319+
'string', 'timedelta'])
2320+
for attr in ['first', 'last']:
2321+
f = getattr(df.groupby('group'), attr)
2322+
result = f(numeric_only=False)
2323+
tm.assert_index_equal(result.columns, expected_columns)
2324+
2325+
expected_columns = Index(['float', 'int', 'string', 'timedelta'])
2326+
for attr in ['sum']:
2327+
f = getattr(df.groupby('group'), attr)
2328+
result = f(numeric_only=False)
2329+
tm.assert_index_equal(result.columns, expected_columns)
2330+
2331+
expected_columns = Index(['float', 'int'])
2332+
for attr in ['prod', 'cumprod']:
2333+
f = getattr(df.groupby('group'), attr)
2334+
result = f(numeric_only=False)
2335+
tm.assert_index_equal(result.columns, expected_columns)
2336+
2337+
expected_columns = Index(['float', 'int', 'timedelta'])
2338+
for attr in ['cumsum']:
2339+
f = getattr(df.groupby('group'), attr)
2340+
result = f(numeric_only=False)
2341+
tm.assert_index_equal(result.columns, expected_columns)
2342+
22632343
def test_cython_agg_boolean(self):
22642344
frame = DataFrame({'a': np.random.randint(0, 5, 50),
22652345
'b': np.random.randint(0, 2, 50).astype('bool')})
@@ -3436,6 +3516,7 @@ def test_int64_overflow(self):
34363516
tups = list(map(tuple, df[['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'
34373517
]].values))
34383518
tups = com._asarray_tuplesafe(tups)
3519+
34393520
expected = df.groupby(tups).sum()['values']
34403521

34413522
for k, v in compat.iteritems(expected):

0 commit comments

Comments
 (0)