Skip to content

Commit 5ec7e18

Browse files
author
tp
committed
add tests for func input to .agg to TestDataFrameAggregate and TestSeriesAggregate
1 parent 39e2e59 commit 5ec7e18

File tree

8 files changed

+100
-81
lines changed

8 files changed

+100
-81
lines changed

doc/source/whatsnew/v0.23.1.txt

+1
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,4 @@ Numeric
9999
^^^^^^^
100100

101101
- :meth:`~DataFrame.agg` now correctly handles numpy NaN-aware methods like :meth:`numpy.nansum` (:issue:`19629`)
102+
- :meth:`~DataFrame.agg` now correctly handles built-in methods like ``sum`` when axis=1 (:issue:`21134`)

pandas/conftest.py

+17
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,20 @@ def tz_aware_fixture(request):
149149
Fixture for trying explicit timezones: {0}
150150
"""
151151
return request.param
152+
153+
154+
@pytest.fixture(
155+
# params: Python 3.5 randomizes dict access and xdist doesn't like that
156+
# in fixtures. In order to get predetermined values we need to sort
157+
# the list deterministically
158+
# GH 21123
159+
params=list(sorted(pd.core.base.SelectionMixin._cython_table.items(),
160+
key=lambda x: x[0].__name__)),
161+
ids=lambda x: "({}-{!r})".format(x[0].__name__, x[1]),
162+
)
163+
def cython_table_items(request):
164+
"""
165+
Fixture for returning the items in
166+
pandas.core.base.SelectionMixin._cython_table
167+
"""
168+
return request.param

pandas/core/base.py

+30-26
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,14 @@ def _try_aggregate_string_function(self, arg, *args, **kwargs):
331331

332332
raise ValueError("{arg} is an unknown string function".format(arg=arg))
333333

334-
def _aggregate(self, arg, *args, **kwargs):
334+
def _aggregate(self, arg, axis=0, *args, **kwargs):
335335
"""
336336
provide an implementation for the aggregators
337337
338338
Parameters
339339
----------
340340
arg : string, dict, function
341+
axis : int
341342
*args : args to pass on to the function
342343
**kwargs : kwargs to pass on to the function
343344
@@ -350,25 +351,26 @@ def _aggregate(self, arg, *args, **kwargs):
350351
how can be a string describe the required post-processing, or
351352
None if not required
352353
"""
354+
obj = self if axis == 0 else self.T
353355
is_aggregator = lambda x: isinstance(x, (list, tuple, dict))
354356
is_nested_renamer = False
355357

356358
_axis = kwargs.pop('_axis', None)
357359
if _axis is None:
358-
_axis = getattr(self, 'axis', 0)
360+
_axis = getattr(obj, 'axis', 0)
359361
_level = kwargs.pop('_level', None)
360362

361363
if isinstance(arg, compat.string_types):
362-
return self._try_aggregate_string_function(arg, *args,
363-
**kwargs), None
364+
return obj._try_aggregate_string_function(arg, *args,
365+
**kwargs), None
364366

365367
if isinstance(arg, dict):
366368

367369
# aggregate based on the passed dict
368370
if _axis != 0: # pragma: no cover
369371
raise ValueError('Can only pass dict with axis=0')
370372

371-
obj = self._selected_obj
373+
selected_obj = obj._selected_obj
372374

373375
def nested_renaming_depr(level=4):
374376
# deprecation of nested renaming
@@ -403,16 +405,16 @@ def nested_renaming_depr(level=4):
403405
if isinstance(v, dict):
404406
is_nested_renamer = True
405407

406-
if k not in obj.columns:
408+
if k not in selected_obj.columns:
407409
msg = ('cannot perform renaming for {key} with a '
408410
'nested dictionary').format(key=k)
409411
raise SpecificationError(msg)
410412
nested_renaming_depr(4 + (_level or 0))
411413

412-
elif isinstance(obj, ABCSeries):
414+
elif isinstance(selected_obj, ABCSeries):
413415
nested_renaming_depr()
414-
elif isinstance(obj, ABCDataFrame) and \
415-
k not in obj.columns:
416+
elif isinstance(selected_obj, ABCDataFrame) and \
417+
k not in selected_obj.columns:
416418
raise KeyError(
417419
"Column '{col}' does not exist!".format(col=k))
418420

@@ -422,8 +424,8 @@ def nested_renaming_depr(level=4):
422424
# deprecation of renaming keys
423425
# GH 15931
424426
keys = list(compat.iterkeys(arg))
425-
if (isinstance(obj, ABCDataFrame) and
426-
len(obj.columns.intersection(keys)) != len(keys)):
427+
if (isinstance(selected_obj, ABCDataFrame) and len(
428+
selected_obj.columns.intersection(keys)) != len(keys)):
427429
nested_renaming_depr()
428430

429431
from pandas.core.reshape.concat import concat
@@ -432,7 +434,7 @@ def _agg_1dim(name, how, subset=None):
432434
"""
433435
aggregate a 1-dim with how
434436
"""
435-
colg = self._gotitem(name, ndim=1, subset=subset)
437+
colg = obj._gotitem(name, ndim=1, subset=subset)
436438
if colg.ndim != 1:
437439
raise SpecificationError("nested dictionary is ambiguous "
438440
"in aggregation")
@@ -442,8 +444,8 @@ def _agg_2dim(name, how):
442444
"""
443445
aggregate a 2-dim with how
444446
"""
445-
colg = self._gotitem(self._selection, ndim=2,
446-
subset=obj)
447+
colg = obj._gotitem(obj._selection, ndim=2,
448+
subset=selected_obj)
447449
return colg.aggregate(how, _level=None)
448450

449451
def _agg(arg, func):
@@ -473,20 +475,22 @@ def _agg(arg, func):
473475

474476
else:
475477

476-
if self._selection is not None:
478+
if obj._selection is not None:
477479
keys = None
478480

479481
# some selection on the object
480-
elif self._selection is not None:
482+
elif obj._selection is not None:
481483

482-
sl = set(self._selection_list)
484+
sl = set(obj._selection_list)
483485

484486
# we are a Series like object,
485487
# but may have multiple aggregations
486488
if len(sl) == 1:
487489

488-
result = _agg(arg, lambda fname,
489-
agg_how: _agg_1dim(self._selection, agg_how))
490+
result = _agg(
491+
arg,
492+
lambda fname, agg_how: _agg_1dim(
493+
obj._selection, agg_how))
490494

491495
# we are selecting the same set as we are aggregating
492496
elif not len(sl - set(keys)):
@@ -531,7 +535,7 @@ def is_any_frame():
531535
return concat([result[k] for k in keys],
532536
keys=keys, axis=1), True
533537

534-
elif isinstance(self, ABCSeries) and is_any_series():
538+
elif isinstance(obj, ABCSeries) and is_any_series():
535539

536540
# we have a dict of Series
537541
# return a MI Series
@@ -556,20 +560,20 @@ def is_any_frame():
556560

557561
# we have a dict of scalars
558562
result = Series(result,
559-
name=getattr(self, 'name', None))
563+
name=getattr(obj, 'name', None))
560564

561565
return result, True
562566
elif is_list_like(arg) and arg not in compat.string_types:
563567
# we require a list, but not an 'str'
564-
return self._aggregate_multiple_funcs(arg,
565-
_level=_level,
566-
_axis=_axis), None
568+
return obj._aggregate_multiple_funcs(arg,
569+
_level=_level,
570+
_axis=_axis), None
567571
else:
568572
result = None
569573

570-
f = self._is_cython_func(arg)
574+
f = obj._is_cython_func(arg)
571575
if f is not None:
572-
return getattr(self, f)(*args, **kwargs), None
576+
return getattr(obj, f)(*args, **kwargs), None
573577

574578
# caller can react
575579
return result, True

pandas/core/frame.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -5818,13 +5818,11 @@ def _gotitem(self,
58185818
def aggregate(self, func, axis=0, *args, **kwargs):
58195819
axis = self._get_axis_number(axis)
58205820

5821-
# TODO: flipped axis
58225821
result = None
5823-
if axis == 0:
5824-
try:
5825-
result, how = self._aggregate(func, axis=0, *args, **kwargs)
5826-
except TypeError:
5827-
pass
5822+
try:
5823+
result, how = self._aggregate(func, axis=axis, *args, **kwargs)
5824+
except TypeError:
5825+
pass
58285826
if result is None:
58295827
return self.apply(func, axis=axis, args=args, **kwargs)
58305828
return result

pandas/core/groupby/groupby.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4086,7 +4086,10 @@ def _post_process_cython_aggregate(self, obj):
40864086
def aggregate(self, arg, *args, **kwargs):
40874087

40884088
_level = kwargs.pop('_level', None)
4089-
result, how = self._aggregate(arg, _level=_level, *args, **kwargs)
4089+
_agg_kwargs = kwargs.copy()
4090+
axis = _agg_kwargs.pop('axis', 0)
4091+
result, how = self._aggregate(arg, axis, _level=_level,
4092+
*args, **_agg_kwargs)
40904093
if how is None:
40914094
return result
40924095

pandas/tests/frame/test_apply.py

+23
Original file line numberDiff line numberDiff line change
@@ -1056,3 +1056,26 @@ def test_non_callable_aggregates(self):
10561056
expected = df.size
10571057

10581058
assert result == expected
1059+
1060+
@pytest.mark.parametrize("df", [
1061+
pd.DataFrame([[1, 2], [3, 4]]),
1062+
pd.DataFrame([[np.nan, 2], [3, 4]]),
1063+
pd.DataFrame(),
1064+
])
1065+
def test_agg_function_input(self, df, cython_table_items):
1066+
# test whether the functions (keys) in
1067+
# pd.core.base.SelectionMixin._cython_table give the same result
1068+
# as the related strings (values) when used in df.agg. Examples:
1069+
# - ``df.agg(np.nansum)`` should give the same result as
1070+
# ``df.agg('sum')``
1071+
# - ``df.agg(sum)`` should give the same result as ``df.agg('sum')``
1072+
# etc.
1073+
# GH21123
1074+
np_func, str_func = cython_table_items
1075+
1076+
tm.assert_almost_equal(df.agg(np_func),
1077+
df.agg(str_func),
1078+
)
1079+
tm.assert_almost_equal(df.agg(np_func, axis=1),
1080+
df.agg(str_func, axis=1),
1081+
)

pandas/tests/series/test_apply.py

+20
Original file line numberDiff line numberDiff line change
@@ -587,3 +587,23 @@ def test_map_missing_mixed(self, vals, mapping, exp):
587587
result = s.map(mapping)
588588

589589
tm.assert_series_equal(result, pd.Series(exp))
590+
591+
@pytest.mark.parametrize("series", [
592+
pd.Series([1, 2, 3, 4]),
593+
pd.Series([np.nan, 2, 3, 4]),
594+
pd.Series(),
595+
])
596+
def test_agg_function_input(self, series, cython_table_items):
597+
# test whether the functions (keys) in
598+
# pd.core.base.SelectionMixin._cython_table give the same result
599+
# as the related strings (values), when used in ser.agg. Examples:
600+
# - ``ser.agg(np.nansum)`` should give the same result as
601+
# ``ser.agg('sum')``
602+
# - ``ser.agg(sum)`` should give the same result as ``ser.agg('sum')``
603+
# etc.
604+
# GH21123
605+
np_func, str_func = cython_table_items
606+
607+
tm.assert_almost_equal(series.agg(np_func),
608+
series.agg(str_func),
609+
)

pandas/tests/test_nanops.py

+1-48
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import pandas.core.nanops as nanops
1414
import pandas.util.testing as tm
1515
import pandas.util._test_decorators as td
16-
from pandas.compat.numpy import (_np_version_under1p13, _np_version_under1p10,
17-
_np_version_under1p12)
16+
from pandas.compat.numpy import _np_version_under1p13
1817

1918
use_bn = nanops._USE_BOTTLENECK
2019

@@ -995,52 +994,6 @@ def prng(self):
995994
return np.random.RandomState(1234)
996995

997996

998-
@pytest.fixture(params=[
999-
pd.Series([1, 2, 3, 4]),
1000-
pd.DataFrame([[1, 2], [3, 4]]),
1001-
pd.Series([np.nan, 2, 3, 4]),
1002-
pd.DataFrame([[np.nan, 2], [3, 4]]),
1003-
pd.Series(),
1004-
pd.DataFrame(),
1005-
pd.Series([np.nan]),
1006-
pd.DataFrame([[np.nan]]),
1007-
])
1008-
def series_or_frame(request):
1009-
return request.param
1010-
1011-
1012-
@pytest.mark.parametrize("standard, nan_method", [
1013-
(np.sum, np.nansum),
1014-
(np.mean, np.nanmean),
1015-
(np.std, np.nanstd),
1016-
(np.var, np.nanvar),
1017-
(np.median, np.nanmedian),
1018-
(np.max, np.nanmax),
1019-
(np.min, np.nanmin),
1020-
], ids=lambda x: x.__name__)
1021-
def test_np_nan_functions(standard, nan_method, series_or_frame):
1022-
tm.assert_almost_equal(series_or_frame.agg(standard),
1023-
series_or_frame.agg(nan_method),
1024-
check_exact=True)
1025-
1026-
1027-
@pytest.mark.skipif(_np_version_under1p10, reason="requires numpy>=1.10")
1028-
def test_np_nanprod(series_or_frame):
1029-
tm.assert_almost_equal(series_or_frame.agg(np.prod),
1030-
series_or_frame.agg(np.nanprod),
1031-
check_exact=True)
1032-
1033-
1034-
@pytest.mark.skipif(_np_version_under1p12, reason="requires numpy>=1.12")
1035-
def test_np_nancumprod(series_or_frame):
1036-
funcs = [(np.cumprod, np.nancumprod),
1037-
(np.cumsum, np.nancumsum)]
1038-
for standard, nan_method in funcs:
1039-
tm.assert_almost_equal(series_or_frame.agg(standard),
1040-
series_or_frame.agg(nan_method),
1041-
check_exact=True)
1042-
1043-
1044997
def test_use_bottleneck():
1045998

1046999
if nanops._BOTTLENECK_INSTALLED:

0 commit comments

Comments
 (0)