Skip to content

Commit 0caac6b

Browse files
author
tp
committed
Fix bug where df.agg(..., axis=1) gives wrong result
1 parent bc37ea2 commit 0caac6b

File tree

7 files changed

+200
-34
lines changed

7 files changed

+200
-34
lines changed

doc/source/whatsnew/v0.23.1.txt

+5
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,8 @@ Categorical
9494
^^^^^^^^^^^
9595

9696
-
97+
98+
Numeric
99+
^^^^^^^
100+
101+
- :meth:`~DataFrame.agg` now correctly handles built-in methods like ``sum`` when ``axis=1`` (:issue:`21134`)

pandas/conftest.py

+17
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,20 @@ def tz_aware_fixture(request):
149149
Fixture for trying explicit timezones: {0}
150150
"""
151151
return request.param
152+
153+
154+
@pytest.fixture(
155+
# params: Python 3.5 randomizes dict order and xdist doesn't like that
156+
# in fixtures. In order to get predetermined values we need to sort
157+
# the list deterministically
158+
# GH 21123
159+
params=list(sorted(pd.core.base.SelectionMixin._cython_table.items(),
160+
key=lambda x: x[0].__name__)),
161+
ids=lambda x: "({}-{!r})_fixture".format(x[0].__name__, x[1]),
162+
)
163+
def cython_table_items(request):
164+
"""
165+
Fixture for returning the items in
166+
pandas.core.base.SelectionMixin._cython_table
167+
"""
168+
return request.param

pandas/core/base.py

+31-27
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,14 @@ def _try_aggregate_string_function(self, arg, *args, **kwargs):
316316

317317
raise ValueError("{arg} is an unknown string function".format(arg=arg))
318318

319-
def _aggregate(self, arg, *args, **kwargs):
319+
def _aggregate(self, arg, axis=0, *args, **kwargs):
320320
"""
321321
provide an implementation for the aggregators
322322
323323
Parameters
324324
----------
325325
arg : string, dict, function
326+
axis : int
326327
*args : args to pass on to the function
327328
**kwargs : kwargs to pass on to the function
328329
@@ -335,25 +336,26 @@ def _aggregate(self, arg, *args, **kwargs):
335336
how can be a string describing the required post-processing, or
336337
None if not required
337338
"""
339+
obj = self if axis == 0 else self.T
338340
is_aggregator = lambda x: isinstance(x, (list, tuple, dict))
339341
is_nested_renamer = False
340342

341343
_axis = kwargs.pop('_axis', None)
342344
if _axis is None:
343-
_axis = getattr(self, 'axis', 0)
345+
_axis = getattr(obj, 'axis', 0)
344346
_level = kwargs.pop('_level', None)
345347

346348
if isinstance(arg, compat.string_types):
347-
return self._try_aggregate_string_function(arg, *args,
348-
**kwargs), None
349+
return obj._try_aggregate_string_function(arg, *args,
350+
**kwargs), None
349351

350352
if isinstance(arg, dict):
351353

352354
# aggregate based on the passed dict
353355
if _axis != 0: # pragma: no cover
354356
raise ValueError('Can only pass dict with axis=0')
355357

356-
obj = self._selected_obj
358+
selected_obj = obj._selected_obj
357359

358360
def nested_renaming_depr(level=4):
359361
# deprecation of nested renaming
@@ -388,16 +390,16 @@ def nested_renaming_depr(level=4):
388390
if isinstance(v, dict):
389391
is_nested_renamer = True
390392

391-
if k not in obj.columns:
393+
if k not in selected_obj.columns:
392394
msg = ('cannot perform renaming for {key} with a '
393395
'nested dictionary').format(key=k)
394396
raise SpecificationError(msg)
395397
nested_renaming_depr(4 + (_level or 0))
396398

397-
elif isinstance(obj, ABCSeries):
399+
elif isinstance(selected_obj, ABCSeries):
398400
nested_renaming_depr()
399-
elif isinstance(obj, ABCDataFrame) and \
400-
k not in obj.columns:
401+
elif isinstance(selected_obj, ABCDataFrame) and \
402+
k not in selected_obj.columns:
401403
raise KeyError(
402404
"Column '{col}' does not exist!".format(col=k))
403405

@@ -407,8 +409,8 @@ def nested_renaming_depr(level=4):
407409
# deprecation of renaming keys
408410
# GH 15931
409411
keys = list(compat.iterkeys(arg))
410-
if (isinstance(obj, ABCDataFrame) and
411-
len(obj.columns.intersection(keys)) != len(keys)):
412+
if (isinstance(selected_obj, ABCDataFrame) and len(
413+
selected_obj.columns.intersection(keys)) != len(keys)):
412414
nested_renaming_depr()
413415

414416
from pandas.core.reshape.concat import concat
@@ -417,7 +419,7 @@ def _agg_1dim(name, how, subset=None):
417419
"""
418420
aggregate a 1-dim with how
419421
"""
420-
colg = self._gotitem(name, ndim=1, subset=subset)
422+
colg = obj._gotitem(name, ndim=1, subset=subset)
421423
if colg.ndim != 1:
422424
raise SpecificationError("nested dictionary is ambiguous "
423425
"in aggregation")
@@ -427,8 +429,8 @@ def _agg_2dim(name, how):
427429
"""
428430
aggregate a 2-dim with how
429431
"""
430-
colg = self._gotitem(self._selection, ndim=2,
431-
subset=obj)
432+
colg = obj._gotitem(obj._selection, ndim=2,
433+
subset=selected_obj)
432434
return colg.aggregate(how, _level=None)
433435

434436
def _agg(arg, func):
@@ -458,20 +460,22 @@ def _agg(arg, func):
458460

459461
else:
460462

461-
if self._selection is not None:
463+
if obj._selection is not None:
462464
keys = None
463465

464466
# some selection on the object
465-
elif self._selection is not None:
467+
elif obj._selection is not None:
466468

467-
sl = set(self._selection_list)
469+
sl = set(obj._selection_list)
468470

469471
# we are a Series like object,
470472
# but may have multiple aggregations
471473
if len(sl) == 1:
472474

473-
result = _agg(arg, lambda fname,
474-
agg_how: _agg_1dim(self._selection, agg_how))
475+
result = _agg(
476+
arg,
477+
lambda fname, agg_how: _agg_1dim(
478+
obj._selection, agg_how))
475479

476480
# we are selecting the same set as we are aggregating
477481
elif not len(sl - set(keys)):
@@ -516,7 +520,7 @@ def is_any_frame():
516520
return concat([result[k] for k in keys],
517521
keys=keys, axis=1), True
518522

519-
elif isinstance(self, ABCSeries) and is_any_series():
523+
elif isinstance(obj, ABCSeries) and is_any_series():
520524

521525
# we have a dict of Series
522526
# return a MI Series
@@ -541,20 +545,20 @@ def is_any_frame():
541545

542546
# we have a dict of scalars
543547
result = Series(result,
544-
name=getattr(self, 'name', None))
548+
name=getattr(obj, 'name', None))
545549

546550
return result, True
547551
elif is_list_like(arg) and arg not in compat.string_types:
548552
# we require a list, but not a 'str'
549-
return self._aggregate_multiple_funcs(arg,
550-
_level=_level,
551-
_axis=_axis), None
553+
return obj._aggregate_multiple_funcs(arg,
554+
_level=_level,
555+
_axis=_axis), None
552556
else:
553557
result = None
554558

555-
f = self._is_cython_func(arg)
556-
if f and not args and not kwargs:
557-
return getattr(self, f)(), None
559+
f = obj._is_cython_func(arg)
560+
if f is not None:
561+
return getattr(obj, f)(*args, **kwargs), None
558562

559563
# caller can react
560564
return result, True

pandas/core/frame.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -5818,13 +5818,11 @@ def _gotitem(self,
58185818
def aggregate(self, func, axis=0, *args, **kwargs):
58195819
axis = self._get_axis_number(axis)
58205820

5821-
# TODO: flipped axis
58225821
result = None
5823-
if axis == 0:
5824-
try:
5825-
result, how = self._aggregate(func, axis=0, *args, **kwargs)
5826-
except TypeError:
5827-
pass
5822+
try:
5823+
result, how = self._aggregate(func, axis=axis, *args, **kwargs)
5824+
except TypeError:
5825+
pass
58285826
if result is None:
58295827
return self.apply(func, axis=axis, args=args, **kwargs)
58305828
return result

pandas/core/groupby/groupby.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4086,7 +4086,10 @@ def _post_process_cython_aggregate(self, obj):
40864086
def aggregate(self, arg, *args, **kwargs):
40874087

40884088
_level = kwargs.pop('_level', None)
4089-
result, how = self._aggregate(arg, _level=_level, *args, **kwargs)
4089+
_agg_kwargs = kwargs.copy()
4090+
axis = _agg_kwargs.pop('axis', 0)
4091+
result, how = self._aggregate(arg, axis, _level=_level,
4092+
*args, **_agg_kwargs)
40904093
if how is None:
40914094
return result
40924095

pandas/tests/frame/test_apply.py

+69
Original file line numberDiff line numberDiff line change
@@ -1056,3 +1056,72 @@ def test_non_callable_aggregates(self):
10561056
expected = df.size
10571057

10581058
assert result == expected
1059+
1060+
@pytest.mark.parametrize("frame, expected_dict", [
1061+
[DataFrame(), {
1062+
'sum': Series(),
1063+
'max': Series(),
1064+
'min': Series(),
1065+
'all': Series(dtype=bool),
1066+
'any': Series(dtype=bool),
1067+
'mean': Series(),
1068+
'prod': Series(),
1069+
'std': Series(),
1070+
'var': Series(),
1071+
'median': Series(),
1072+
'cumprod': DataFrame(),
1073+
'cumsum': DataFrame(),
1074+
}],
1075+
[DataFrame([[np.nan, 1], [1, 2]]), {
1076+
'sum': Series([1., 3]),
1077+
'max': Series([1., 2]),
1078+
'min': Series([1., 1]),
1079+
'all': Series([True, True]),
1080+
'any': Series([True, True]),
1081+
'mean': Series([1, 1.5]),
1082+
'prod': Series([1., 2]),
1083+
'std': Series([np.nan, 0.707107]),
1084+
'var': Series([np.nan, 0.5]),
1085+
'median': Series([1, 1.5]),
1086+
'cumprod': DataFrame([[np.nan, 1], [1., 2.]]),
1087+
'cumsum': DataFrame([[np.nan, 1], [1., 3.]]),
1088+
}],
1089+
[DataFrame([['a', 'b'], ['b', 'a']]), {
1090+
'sum': Series(['ab', 'ba']),
1091+
'max': Series(['b', 'b']),
1092+
'min': Series(['a', 'a']),
1093+
'all': Series([True, True]),
1094+
'any': Series([True, True]),
1095+
'mean': Series([], index=pd.Index([], dtype='int64')),
1096+
'prod': Series([], index=pd.Index([], dtype='int64')),
1097+
'std': Series([], index=pd.Index([], dtype='int64')),
1098+
'var': Series([], index=pd.Index([], dtype='int64')),
1099+
'median': Series([], index=pd.Index([], dtype='int64')),
1100+
'cumprod': TypeError,
1101+
'cumsum': DataFrame([['a', 'b'], ['ab', 'ba']]),
1102+
}],
1103+
])
1104+
@pytest.mark.parametrize("axis", [0, 1], ids=lambda x: "axis {}".format(x))
1105+
def test_agg_function_input(self, cython_table_items,
1106+
frame, expected_dict, axis):
1107+
# GH21123
1108+
# test if using items in _cython_table gives correct results
1109+
np_func, str_func = cython_table_items
1110+
expected = expected_dict[str_func]
1111+
1112+
if isinstance(expected, type) and issubclass(expected, Exception):
1113+
with pytest.raises(expected):
1114+
# e.g. DataFrame(['a b'.split()]).cumprod() will raise
1115+
frame.agg(np_func, axis=axis)
1116+
with pytest.raises(expected):
1117+
frame.agg(str_func, axis=axis)
1118+
return
1119+
1120+
result = frame.agg(np_func, axis=axis)
1121+
result_str_func = frame.agg(str_func, axis=axis)
1122+
if str_func in ('cumprod', 'cumsum'):
1123+
tm.assert_frame_equal(result, expected)
1124+
tm.assert_frame_equal(result_str_func, expected)
1125+
else:
1126+
tm.assert_series_equal(result, expected)
1127+
tm.assert_series_equal(result_str_func, expected)

pandas/tests/series/test_apply.py

+70
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,76 @@ def test_non_callable_aggregates(self):
331331
('mean', 1.5)]))
332332
assert_series_equal(result[expected.index], expected)
333333

334+
@pytest.mark.parametrize("series, expected_dict", [
335+
[Series(), {
336+
'sum': 0,
337+
'max': np.nan,
338+
'min': np.nan,
339+
'all': True,
340+
'any': False,
341+
'mean': np.nan,
342+
'prod': 1,
343+
'std': np.nan,
344+
'var': np.nan,
345+
'median': np.nan,
346+
'cumprod': Series([], Index([])),
347+
'cumsum': Series([], Index([])),
348+
}],
349+
[Series([np.nan, 1, 2, 3]), {
350+
'sum': 6,
351+
'max': 3,
352+
'min': 1,
353+
'all': True,
354+
'any': True,
355+
'mean': 2,
356+
'prod': 6,
357+
'std': 1,
358+
'var': 1,
359+
'median': 2,
360+
'cumprod': Series([np.nan, 1, 2, 6]),
361+
'cumsum': Series([np.nan, 1, 3, 6]),
362+
}],
363+
[Series('a b c'.split()), {
364+
'sum': 'abc',
365+
'max': 'c',
366+
'min': 'a',
367+
'all': 'c', # see GH12863
368+
'any': 'a',
369+
'mean': TypeError, # mean raises TypeError
370+
'prod': TypeError,
371+
'std': TypeError,
372+
'var': TypeError,
373+
'median': TypeError,
374+
'cumprod': TypeError,
375+
'cumsum': Series(['a', 'ab', 'abc']),
376+
}],
377+
])
378+
def test_agg_cython_table_input(self, cython_table_items,
379+
series, expected_dict):
380+
# GH21123
381+
# test if using items in _cython_table gives correct results
382+
np_func, str_func = cython_table_items
383+
expected = expected_dict[str_func]
384+
385+
if isinstance(expected, type) and issubclass(expected, Exception):
386+
with pytest.raises(expected):
387+
series.agg(np_func)
388+
with pytest.raises(expected):
389+
series.agg(str_func)
390+
return
391+
392+
result = series.agg(np_func)
393+
result_str_func = series.agg(str_func)
394+
if str_func in ('cumprod', 'cumsum'):
395+
tm.assert_series_equal(result, expected)
396+
tm.assert_series_equal(result_str_func, expected)
397+
elif tm.is_number(expected):
398+
assert np.isclose(result, expected, equal_nan=True)
399+
assert np.isclose(result_str_func, expected, equal_nan=True)
400+
else:
401+
assert result == expected
402+
assert result_str_func == expected
403+
334404

335405
class TestSeriesMap(TestData):
336406

0 commit comments

Comments
 (0)