Skip to content

Commit 638b0ad

Browse files
topper-123victor
authored and
victor
committed
BUG: df.agg, df.transform and df.apply use different methods when axis=1 than when axis=0 (pandas-dev#21224)
1 parent 10a95d5 commit 638b0ad

File tree

8 files changed

+338
-98
lines changed

8 files changed

+338
-98
lines changed

doc/source/whatsnew/v0.24.0.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,9 @@ Numeric
475475
- Bug in :class:`Series` ``__rmatmul__`` doesn't support matrix vector multiplication (:issue:`21530`)
476476
- Bug in :func:`factorize` fails with read-only array (:issue:`12813`)
477477
- Fixed bug in :func:`unique` handled signed zeros inconsistently: for some inputs 0.0 and -0.0 were treated as equal and for some inputs as different. Now they are treated as equal for all inputs (:issue:`21866`)
478-
-
478+
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` where,
479+
when supplied with a list of functions and ``axis=1`` (e.g. ``df.apply(['sum', 'mean'], axis=1)``),
480+
a ``TypeError`` was wrongly raised. For all three methods such calculation are now done correctly. (:issue:`16679`).
479481
-
480482

481483
Strings

pandas/conftest.py

+55
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ def spmatrix(request):
6060
return getattr(sparse, request.param + '_matrix')
6161

6262

63+
@pytest.fixture(params=[0, 1, 'index', 'columns'],
64+
ids=lambda x: "axis {!r}".format(x))
65+
def axis(request):
66+
"""
67+
Fixture for returning the axis numbers of a DataFrame.
68+
"""
69+
return request.param
70+
71+
72+
axis_frame = axis
73+
74+
75+
@pytest.fixture(params=[0, 'index'], ids=lambda x: "axis {!r}".format(x))
76+
def axis_series(request):
77+
"""
78+
Fixture for returning the axis numbers of a Series.
79+
"""
80+
return request.param
81+
82+
6383
@pytest.fixture
6484
def ip():
6585
"""
@@ -103,6 +123,41 @@ def all_arithmetic_operators(request):
103123
return request.param
104124

105125

126+
# use sorted as dicts in py<3.6 have random order, which xdist doesn't like
127+
_cython_table = sorted(((key, value) for key, value in
128+
pd.core.base.SelectionMixin._cython_table.items()),
129+
key=lambda x: x[0].__class__.__name__)
130+
131+
132+
@pytest.fixture(params=_cython_table)
133+
def cython_table_items(request):
134+
return request.param
135+
136+
137+
def _get_cython_table_params(ndframe, func_names_and_expected):
138+
"""combine frame, functions from SelectionMixin._cython_table
139+
keys and expected result.
140+
141+
Parameters
142+
----------
143+
ndframe : DataFrame or Series
144+
func_names_and_expected : Sequence of two items
145+
The first item is a name of a NDFrame method ('sum', 'prod') etc.
146+
The second item is the expected return value
147+
148+
Returns
149+
-------
150+
results : list
151+
List of three items (DataFrame, function, expected result)
152+
"""
153+
results = []
154+
for func_name, expected in func_names_and_expected:
155+
results.append((ndframe, func_name, expected))
156+
results += [(ndframe, func, expected) for func, name in _cython_table
157+
if name == func_name]
158+
return results
159+
160+
106161
@pytest.fixture(params=['__eq__', '__ne__', '__le__',
107162
'__lt__', '__ge__', '__gt__'])
108163
def all_compare_operators(request):

pandas/core/apply.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from pandas.core.dtypes.generic import ABCSeries
66
from pandas.core.dtypes.common import (
77
is_extension_type,
8+
is_dict_like,
9+
is_list_like,
810
is_sequence)
911
from pandas.util._decorators import cache_readonly
1012

@@ -105,6 +107,11 @@ def agg_axis(self):
105107
def get_result(self):
106108
""" compute the results """
107109

110+
# dispatch to agg
111+
if is_list_like(self.f) or is_dict_like(self.f):
112+
return self.obj.aggregate(self.f, axis=self.axis,
113+
*self.args, **self.kwds)
114+
108115
# all empty
109116
if len(self.columns) == 0 and len(self.index) == 0:
110117
return self.apply_empty_result()
@@ -308,15 +315,6 @@ def wrap_results(self):
308315
class FrameRowApply(FrameApply):
309316
axis = 0
310317

311-
def get_result(self):
312-
313-
# dispatch to agg
314-
if isinstance(self.f, (list, dict)):
315-
return self.obj.aggregate(self.f, axis=self.axis,
316-
*self.args, **self.kwds)
317-
318-
return super(FrameRowApply, self).get_result()
319-
320318
def apply_broadcast(self):
321319
return super(FrameRowApply, self).apply_broadcast(self.obj)
322320

pandas/core/frame.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -6070,19 +6070,34 @@ def _gotitem(self,
60706070
def aggregate(self, func, axis=0, *args, **kwargs):
60716071
axis = self._get_axis_number(axis)
60726072

6073-
# TODO: flipped axis
60746073
result = None
6075-
if axis == 0:
6076-
try:
6077-
result, how = self._aggregate(func, axis=0, *args, **kwargs)
6078-
except TypeError:
6079-
pass
6074+
try:
6075+
result, how = self._aggregate(func, axis=axis, *args, **kwargs)
6076+
except TypeError:
6077+
pass
60806078
if result is None:
60816079
return self.apply(func, axis=axis, args=args, **kwargs)
60826080
return result
60836081

6082+
def _aggregate(self, arg, axis=0, *args, **kwargs):
6083+
if axis == 1:
6084+
# NDFrame.aggregate returns a tuple, and we need to transpose
6085+
# only result
6086+
result, how = (super(DataFrame, self.T)
6087+
._aggregate(arg, *args, **kwargs))
6088+
result = result.T if result is not None else result
6089+
return result, how
6090+
return super(DataFrame, self)._aggregate(arg, *args, **kwargs)
6091+
60846092
agg = aggregate
60856093

6094+
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
6095+
def transform(self, func, axis=0, *args, **kwargs):
6096+
axis = self._get_axis_number(axis)
6097+
if axis == 1:
6098+
return super(DataFrame, self.T).transform(func, *args, **kwargs).T
6099+
return super(DataFrame, self).transform(func, *args, **kwargs)
6100+
60866101
def apply(self, func, axis=0, broadcast=None, raw=False, reduce=None,
60876102
result_type=None, args=(), **kwds):
60886103
"""

pandas/core/generic.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -9193,16 +9193,14 @@ def ewm(self, com=None, span=None, halflife=None, alpha=None,
91939193

91949194
cls.ewm = ewm
91959195

9196-
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9197-
def transform(self, func, *args, **kwargs):
9198-
result = self.agg(func, *args, **kwargs)
9199-
if is_scalar(result) or len(result) != len(self):
9200-
raise ValueError("transforms cannot produce "
9201-
"aggregated results")
9196+
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9197+
def transform(self, func, *args, **kwargs):
9198+
result = self.agg(func, *args, **kwargs)
9199+
if is_scalar(result) or len(result) != len(self):
9200+
raise ValueError("transforms cannot produce "
9201+
"aggregated results")
92029202

9203-
return result
9204-
9205-
cls.transform = transform
9203+
return result
92069204

92079205
# ----------------------------------------------------------------------
92089206
# Misc methods

0 commit comments

Comments
 (0)