Skip to content

Commit 3ac8f0b

Browse files
tptopper-123
tp
authored andcommitted
correct apply(axis=1) and related bugs
1 parent 08402f9 commit 3ac8f0b

File tree

5 files changed

+88
-59
lines changed

5 files changed

+88
-59
lines changed

doc/source/whatsnew/v0.24.0.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,11 @@ Numeric
475475
- Bug in :class:`Series` ``__rmatmul__`` doesn't support matrix vector multiplication (:issue:`21530`)
476476
- Bug in :func:`factorize` fails with read-only array (:issue:`12813`)
477477
- Fixed bug in :func:`unique` handled signed zeros inconsistently: for some inputs 0.0 and -0.0 were treated as equal and for some inputs as different. Now they are treated as equal for all inputs (:issue:`21866`)
478-
-
478+
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` when ``axis=1``.
479+
Using ``apply`` with a list of functions and axis=1 (e.g. ``df.apply(['abs'], axis=1)``)
480+
previously gave a TypeError. This fixes that issue.
481+
As ``agg`` and ``transform`` in some cases delegate to ``apply``, this also
482+
fixed this issue for them (:issue:`16679`).
479483
-
480484

481485
Strings

pandas/core/apply.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def agg_axis(self):
105105
def get_result(self):
106106
""" compute the results """
107107

108+
# dispatch to agg
109+
if isinstance(self.f, (list, dict)):
110+
return self.obj.aggregate(self.f, axis=self.axis,
111+
*self.args, **self.kwds)
112+
108113
# all empty
109114
if len(self.columns) == 0 and len(self.index) == 0:
110115
return self.apply_empty_result()
@@ -308,15 +313,6 @@ def wrap_results(self):
308313
class FrameRowApply(FrameApply):
309314
axis = 0
310315

311-
def get_result(self):
312-
313-
# dispatch to agg
314-
if isinstance(self.f, (list, dict)):
315-
return self.obj.aggregate(self.f, axis=self.axis,
316-
*self.args, **self.kwds)
317-
318-
return super(FrameRowApply, self).get_result()
319-
320316
def apply_broadcast(self):
321317
return super(FrameRowApply, self).apply_broadcast(self.obj)
322318

pandas/core/frame.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -6079,11 +6079,20 @@ def aggregate(self, func, axis=0, *args, **kwargs):
60796079
return result
60806080

60816081
def _aggregate(self, arg, axis=0, *args, **kwargs):
6082-
obj = self.T if axis == 1 else self
6083-
return super(DataFrame, obj)._aggregate(arg, *args, **kwargs)
6082+
if axis == 1:
6083+
result, how = (super(DataFrame, self.T)
6084+
._aggregate(arg, *args, **kwargs))
6085+
result = result.T if result is not None else result
6086+
return result, how
6087+
return super(DataFrame, self)._aggregate(arg, *args, **kwargs)
60846088

60856089
agg = aggregate
60866090

6091+
def transform(self, func, axis=0, *args, **kwargs):
6092+
if axis == 1:
6093+
return super(DataFrame, self.T).transform(func, *args, **kwargs).T
6094+
return super(DataFrame, self).transform(func, *args, **kwargs)
6095+
60876096
def apply(self, func, axis=0, broadcast=None, raw=False, reduce=None,
60886097
result_type=None, args=(), **kwds):
60896098
"""

pandas/core/generic.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -9193,16 +9193,14 @@ def ewm(self, com=None, span=None, halflife=None, alpha=None,
91939193

91949194
cls.ewm = ewm
91959195

9196-
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9197-
def transform(self, func, *args, **kwargs):
9198-
result = self.agg(func, *args, **kwargs)
9199-
if is_scalar(result) or len(result) != len(self):
9200-
raise ValueError("transforms cannot produce "
9201-
"aggregated results")
9196+
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
9197+
def transform(self, func, *args, **kwargs):
9198+
result = self.agg(func, *args, **kwargs)
9199+
if is_scalar(result) or len(result) != len(self):
9200+
raise ValueError("transforms cannot produce "
9201+
"aggregated results")
92029202

9203-
return result
9204-
9205-
cls.transform = transform
9203+
return result
92069204

92079205
# ----------------------------------------------------------------------
92089206
# Misc methods

pandas/tests/frame/test_apply.py

+60-38
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import operator
8+
from collections import OrderedDict
89
from datetime import datetime
910
from itertools import chain
1011

@@ -846,58 +847,74 @@ def test_consistency_for_boxed(self, box):
846847
assert_frame_equal(result, expected)
847848

848849

849-
def zip_frames(*frames):
850+
def zip_frames(frames, axis=1):
850851
"""
851-
take a list of frames, zip the columns together for each
852-
assume that these all have the first frame columns
852+
take a list of frames, zip them together under the
853+
assumption that these all have the first frames' index/columns.
853854
854-
return a new frame
855+
Returns
856+
-------
857+
new_frame : DataFrame
855858
"""
856-
columns = frames[0].columns
857-
zipped = [f[c] for c in columns for f in frames]
858-
return pd.concat(zipped, axis=1)
859+
if axis == 1:
860+
columns = frames[0].columns
861+
zipped = [f.loc[:, c] for c in columns for f in frames]
862+
return pd.concat(zipped, axis=1)
863+
else:
864+
index = frames[0].index
865+
zipped = [f.loc[i, :] for i in index for f in frames]
866+
return pd.DataFrame(zipped)
859867

860868

861869
class TestDataFrameAggregate(TestData):
862870

863-
def test_agg_transform(self):
871+
def test_agg_transform(self, axis):
872+
other_axis = abs(axis - 1)
864873

865874
with np.errstate(all='ignore'):
866875

867-
f_sqrt = np.sqrt(self.frame)
868876
f_abs = np.abs(self.frame)
877+
f_sqrt = np.sqrt(self.frame)
869878

870879
# ufunc
871-
result = self.frame.transform(np.sqrt)
880+
result = self.frame.transform(np.sqrt, axis=axis)
872881
expected = f_sqrt.copy()
873882
assert_frame_equal(result, expected)
874883

875-
result = self.frame.apply(np.sqrt)
884+
result = self.frame.apply(np.sqrt, axis=axis)
876885
assert_frame_equal(result, expected)
877886

878-
result = self.frame.transform(np.sqrt)
887+
result = self.frame.transform(np.sqrt, axis=axis)
879888
assert_frame_equal(result, expected)
880889

881890
# list-like
882-
result = self.frame.apply([np.sqrt])
891+
result = self.frame.apply([np.sqrt], axis=axis)
883892
expected = f_sqrt.copy()
884-
expected.columns = pd.MultiIndex.from_product(
885-
[self.frame.columns, ['sqrt']])
893+
if axis == 0:
894+
expected.columns = pd.MultiIndex.from_product(
895+
[self.frame.columns, ['sqrt']])
896+
else:
897+
expected.index = pd.MultiIndex.from_product(
898+
[self.frame.index, ['sqrt']])
886899
assert_frame_equal(result, expected)
887900

888-
result = self.frame.transform([np.sqrt])
901+
result = self.frame.transform([np.sqrt], axis=axis)
889902
assert_frame_equal(result, expected)
890903

891904
# multiple items in list
892905
# these are in the order as if we are applying both
893906
# functions per series and then concatting
894-
expected = zip_frames(f_sqrt, f_abs)
895-
expected.columns = pd.MultiIndex.from_product(
896-
[self.frame.columns, ['sqrt', 'absolute']])
897-
result = self.frame.apply([np.sqrt, np.abs])
907+
result = self.frame.apply([np.abs, np.sqrt], axis=axis)
908+
expected = zip_frames([f_abs, f_sqrt], axis=other_axis)
909+
if axis == 0:
910+
expected.columns = pd.MultiIndex.from_product(
911+
[self.frame.columns, ['absolute', 'sqrt']])
912+
else:
913+
expected.index = pd.MultiIndex.from_product(
914+
[self.frame.index, ['absolute', 'sqrt']])
898915
assert_frame_equal(result, expected)
899916

900-
result = self.frame.transform(['sqrt', np.abs])
917+
result = self.frame.transform([np.abs, 'sqrt'], axis=axis)
901918
assert_frame_equal(result, expected)
902919

903920
def test_transform_and_agg_err(self, axis):
@@ -985,46 +1002,51 @@ def test_agg_dict_nested_renaming_depr(self):
9851002

9861003
def test_agg_reduce(self, axis):
9871004
other_axis = abs(axis - 1)
988-
name1, name2 = self.frame.axes[other_axis].unique()[:2]
1005+
name1, name2 = self.frame.axes[other_axis].unique()[:2].sort_values()
9891006

9901007
# all reducers
991-
expected = zip_frames(self.frame.mean(axis=axis).to_frame(),
992-
self.frame.max(axis=axis).to_frame(),
993-
self.frame.sum(axis=axis).to_frame()).T
994-
expected.index = ['mean', 'max', 'sum']
1008+
expected = pd.concat([self.frame.mean(axis=axis),
1009+
self.frame.max(axis=axis),
1010+
self.frame.sum(axis=axis),
1011+
], axis=1)
1012+
expected.columns = ['mean', 'max', 'sum']
1013+
expected = expected.T if axis == 0 else expected
1014+
9951015
result = self.frame.agg(['mean', 'max', 'sum'], axis=axis)
9961016
assert_frame_equal(result, expected)
9971017

9981018
# dict input with scalars
999-
func = {name1: 'mean', name2: 'sum'}
1019+
func = OrderedDict([(name1, 'mean'), (name2, 'sum')])
10001020
result = self.frame.agg(func, axis=axis)
10011021
expected = Series([self.frame.loc(other_axis)[name1].mean(),
10021022
self.frame.loc(other_axis)[name2].sum()],
10031023
index=[name1, name2])
1004-
assert_series_equal(result.reindex_like(expected), expected)
1024+
assert_series_equal(result, expected)
10051025

10061026
# dict input with lists
1007-
func = {name1: ['mean'], name2: ['sum']}
1027+
func = OrderedDict([(name1, ['mean']), (name2, ['sum'])])
10081028
result = self.frame.agg(func, axis=axis)
10091029
expected = DataFrame({
10101030
name1: Series([self.frame.loc(other_axis)[name1].mean()],
10111031
index=['mean']),
10121032
name2: Series([self.frame.loc(other_axis)[name2].sum()],
10131033
index=['sum'])})
1014-
assert_frame_equal(result.reindex_like(expected), expected)
1034+
expected = expected.T if axis == 1 else expected
1035+
assert_frame_equal(result, expected)
10151036

10161037
# dict input with lists with multiple
1017-
func = {name1: ['mean', 'sum'],
1018-
name2: ['sum', 'max']}
1038+
func = OrderedDict([(name1, ['mean', 'sum']), (name2, ['sum', 'max'])])
10191039
result = self.frame.agg(func, axis=axis)
1020-
expected = DataFrame({
1021-
name1: Series([self.frame.loc(other_axis)[name1].mean(),
1040+
expected = DataFrame(OrderedDict([
1041+
(name1, Series([self.frame.loc(other_axis)[name1].mean(),
10221042
self.frame.loc(other_axis)[name1].sum()],
1023-
index=['mean', 'sum']),
1024-
name2: Series([self.frame.loc(other_axis)[name2].sum(),
1043+
index=['mean', 'sum'])),
1044+
(name2, Series([self.frame.loc(other_axis)[name2].sum(),
10251045
self.frame.loc(other_axis)[name2].max()],
1026-
index=['sum', 'max'])})
1027-
assert_frame_equal(result.reindex_like(expected), expected)
1046+
index=['sum', 'max'])),
1047+
]))
1048+
expected = expected.T if axis == 1 else expected
1049+
assert_frame_equal(result, expected)
10281050

10291051
def test_nuiscance_columns(self):
10301052

0 commit comments

Comments
 (0)