Skip to content

Commit 4b5b2c3

Browse files
author
tp
committed
correct apply(axis=1) and related bugs
1 parent 17dc5b9 commit 4b5b2c3

File tree

5 files changed

+78
-50
lines changed

5 files changed

+78
-50
lines changed

doc/source/whatsnew/v0.24.0.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,11 @@ Numeric
405405

406406
- Bug in :class:`Series` ``__rmatmul__`` doesn't support matrix vector multiplication (:issue:`21530`)
407407
- Bug in :func:`factorize` fails with read-only array (:issue:`12813`)
408-
-
408+
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` when ``axis=1``.
409+
Using ``apply`` with a list on functions and axis=1 (e.g. ``df.apply(['abs'], axis=1)``)
410+
previously gave an TypeError, while the operation worked with axis=0. This fixes that issue.
411+
As ``agg`` and ``transform`` in many cases delegated to ``apply``, this also
412+
fixed this issue for them also (:issue:`16679`).
409413
-
410414

411415
Strings

pandas/core/apply.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def agg_axis(self):
105105
def get_result(self):
106106
""" compute the results """
107107

108+
# dispatch to agg
109+
if isinstance(self.f, (list, dict)):
110+
return self.obj.aggregate(self.f, axis=self.axis,
111+
*self.args, **self.kwds)
112+
108113
# all empty
109114
if len(self.columns) == 0 and len(self.index) == 0:
110115
return self.apply_empty_result()
@@ -308,15 +313,6 @@ def wrap_results(self):
308313
class FrameRowApply(FrameApply):
309314
axis = 0
310315

311-
def get_result(self):
312-
313-
# dispatch to agg
314-
if isinstance(self.f, (list, dict)):
315-
return self.obj.aggregate(self.f, axis=self.axis,
316-
*self.args, **self.kwds)
317-
318-
return super(FrameRowApply, self).get_result()
319-
320316
def apply_broadcast(self):
321317
return super(FrameRowApply, self).apply_broadcast(self.obj)
322318

pandas/core/frame.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -6077,11 +6077,20 @@ def aggregate(self, func, axis=0, *args, **kwargs):
60776077
return result
60786078

60796079
def _aggregate(self, arg, axis=0, *args, **kwargs):
6080-
obj = self.T if axis == 1 else self
6081-
return super(DataFrame, obj)._aggregate(arg, *args, **kwargs)
6080+
if axis == 1:
6081+
result, how = (super(DataFrame, self.T)
6082+
._aggregate(arg, *args, **kwargs))
6083+
result = result.T if result is not None else result
6084+
return result, how
6085+
return super(DataFrame, self)._aggregate(arg, *args, **kwargs)
60826086

60836087
agg = aggregate
60846088

6089+
def transform(self, func, axis=0, *args, **kwargs):
6090+
if axis == 1:
6091+
return super(DataFrame, self.T).transform(func, *args, **kwargs).T
6092+
return super(DataFrame, self).transform(func, *args, **kwargs)
6093+
60856094
def apply(self, func, axis=0, broadcast=None, raw=False, reduce=None,
60866095
result_type=None, args=(), **kwds):
60876096
"""

pandas/core/generic.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -9090,16 +9090,14 @@ def ewm(self, com=None, span=None, halflife=None, alpha=None,
90909090

90919091
cls.ewm = ewm
90929092

9093-
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9094-
def transform(self, func, *args, **kwargs):
9095-
result = self.agg(func, *args, **kwargs)
9096-
if is_scalar(result) or len(result) != len(self):
9097-
raise ValueError("transforms cannot produce "
9098-
"aggregated results")
9093+
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9094+
def transform(self, func, *args, **kwargs):
9095+
result = self.agg(func, *args, **kwargs)
9096+
if is_scalar(result) or len(result) != len(self):
9097+
raise ValueError("transforms cannot produce "
9098+
"aggregated results")
90999099

9100-
return result
9101-
9102-
cls.transform = transform
9100+
return result
91039101

91049102
# ----------------------------------------------------------------------
91059103
# Misc methods

pandas/tests/frame/test_apply.py

+50-29
Original file line numberDiff line numberDiff line change
@@ -846,58 +846,74 @@ def test_consistency_for_boxed(self, box):
846846
assert_frame_equal(result, expected)
847847

848848

849-
def zip_frames(*frames):
849+
def zip_frames(*frames, axis=1):
850850
"""
851-
take a list of frames, zip the columns together for each
852-
assume that these all have the first frame columns
851+
take a list of frames, zip them together under the
852+
assumption that these all have the first frames' index/columns.
853853
854-
return a new frame
854+
Returns
855+
-------
856+
new_frame : DataFrame
855857
"""
856-
columns = frames[0].columns
857-
zipped = [f[c] for c in columns for f in frames]
858-
return pd.concat(zipped, axis=1)
858+
if axis == 1:
859+
columns = frames[0].columns
860+
zipped = [f.loc[:, c] for c in columns for f in frames]
861+
return pd.concat(zipped, axis=1)
862+
else:
863+
index = frames[0].index
864+
zipped = [f.loc[i, :] for i in index for f in frames]
865+
return pd.DataFrame(zipped)
859866

860867

861868
class TestDataFrameAggregate(TestData):
862869

863-
def test_agg_transform(self):
870+
def test_agg_transform(self, axis):
871+
other_axis = abs(axis - 1)
864872

865873
with np.errstate(all='ignore'):
866874

867-
f_sqrt = np.sqrt(self.frame)
868875
f_abs = np.abs(self.frame)
876+
f_sqrt = np.sqrt(self.frame)
869877

870878
# ufunc
871-
result = self.frame.transform(np.sqrt)
879+
result = self.frame.transform(np.sqrt, axis=axis)
872880
expected = f_sqrt.copy()
873881
assert_frame_equal(result, expected)
874882

875-
result = self.frame.apply(np.sqrt)
883+
result = self.frame.apply(np.sqrt, axis=axis)
876884
assert_frame_equal(result, expected)
877885

878-
result = self.frame.transform(np.sqrt)
886+
result = self.frame.transform(np.sqrt, axis=axis)
879887
assert_frame_equal(result, expected)
880888

881889
# list-like
882-
result = self.frame.apply([np.sqrt])
890+
result = self.frame.apply([np.sqrt], axis=axis)
883891
expected = f_sqrt.copy()
884-
expected.columns = pd.MultiIndex.from_product(
885-
[self.frame.columns, ['sqrt']])
892+
if axis == 0:
893+
expected.columns = pd.MultiIndex.from_product(
894+
[self.frame.columns, ['sqrt']])
895+
else:
896+
expected.index = pd.MultiIndex.from_product(
897+
[self.frame.index, ['sqrt']])
886898
assert_frame_equal(result, expected)
887899

888-
result = self.frame.transform([np.sqrt])
900+
result = self.frame.transform([np.sqrt], axis=axis)
889901
assert_frame_equal(result, expected)
890902

891903
# multiple items in list
892904
# these are in the order as if we are applying both
893905
# functions per series and then concatting
894-
expected = zip_frames(f_sqrt, f_abs)
895-
expected.columns = pd.MultiIndex.from_product(
896-
[self.frame.columns, ['sqrt', 'absolute']])
897-
result = self.frame.apply([np.sqrt, np.abs])
906+
result = self.frame.apply([np.abs, np.sqrt], axis=axis)
907+
expected = zip_frames(f_abs, f_sqrt, axis=other_axis)
908+
if axis == 0:
909+
expected.columns = pd.MultiIndex.from_product(
910+
[self.frame.columns, ['absolute', 'sqrt']])
911+
else:
912+
expected.index = pd.MultiIndex.from_product(
913+
[self.frame.index, ['absolute', 'sqrt']])
898914
assert_frame_equal(result, expected)
899915

900-
result = self.frame.transform(['sqrt', np.abs])
916+
result = self.frame.transform([np.abs, 'sqrt'], axis=axis)
901917
assert_frame_equal(result, expected)
902918

903919
def test_transform_and_agg_err(self, axis):
@@ -985,13 +1001,16 @@ def test_agg_dict_nested_renaming_depr(self):
9851001

9861002
def test_agg_reduce(self, axis):
9871003
other_axis = abs(axis - 1)
988-
name1, name2 = self.frame.axes[other_axis].unique()[:2]
1004+
name1, name2 = self.frame.axes[other_axis].unique()[:2].sort_values()
9891005

9901006
# all reducers
991-
expected = zip_frames(self.frame.mean(axis=axis).to_frame(),
992-
self.frame.max(axis=axis).to_frame(),
993-
self.frame.sum(axis=axis).to_frame()).T
994-
expected.index = ['mean', 'max', 'sum']
1007+
expected = pd.concat([self.frame.mean(axis=axis),
1008+
self.frame.max(axis=axis),
1009+
self.frame.sum(axis=axis),
1010+
], axis=1)
1011+
expected.columns = ['mean', 'max', 'sum']
1012+
expected = expected.T if axis == 0 else expected
1013+
9951014
result = self.frame.agg(['mean', 'max', 'sum'], axis=axis)
9961015
assert_frame_equal(result, expected)
9971016

@@ -1001,7 +1020,7 @@ def test_agg_reduce(self, axis):
10011020
expected = Series([self.frame.loc(other_axis)[name1].mean(),
10021021
self.frame.loc(other_axis)[name2].sum()],
10031022
index=[name1, name2])
1004-
assert_series_equal(result.reindex_like(expected), expected)
1023+
assert_series_equal(result, expected)
10051024

10061025
# dict input with lists
10071026
func = {name1: ['mean'], name2: ['sum']}
@@ -1011,7 +1030,8 @@ def test_agg_reduce(self, axis):
10111030
index=['mean']),
10121031
name2: Series([self.frame.loc(other_axis)[name2].sum()],
10131032
index=['sum'])})
1014-
assert_frame_equal(result.reindex_like(expected), expected)
1033+
expected = expected.T if axis == 1 else expected
1034+
assert_frame_equal(result, expected)
10151035

10161036
# dict input with lists with multiple
10171037
func = {name1: ['mean', 'sum'],
@@ -1024,7 +1044,8 @@ def test_agg_reduce(self, axis):
10241044
name2: Series([self.frame.loc(other_axis)[name2].sum(),
10251045
self.frame.loc(other_axis)[name2].max()],
10261046
index=['sum', 'max'])})
1027-
assert_frame_equal(result.reindex_like(expected), expected)
1047+
expected = expected.T if axis == 1 else expected
1048+
assert_frame_equal(result, expected)
10281049

10291050
def test_nuiscance_columns(self):
10301051

0 commit comments

Comments
 (0)