Skip to content

Commit 203f411

Browse files
committed
BUG: GroupBy.apply bug with differently indexed MultiIndex objects, test coverage
1 parent 3b920ae commit 203f411

File tree

6 files changed

+75
-20
lines changed

6 files changed

+75
-20
lines changed

RELEASE.rst

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ feedback on the library.
194194
- Can pass level name to `DataFrame.stack`
195195
- Support set operations between MultiIndex and Index
196196
- Fix many corner cases in MultiIndex set operations
197+
- Fix MultiIndex-handling bug in `GroupBy.apply` when the returned groups are
198+
not indexed the same as the input groups
197199

198200
Thanks
199201
------

pandas/core/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def __call__(self, num):
416416

417417
sign = 1
418418

419-
if dnum < 0:
419+
if dnum < 0: # pragma: no cover
420420
sign = -1
421421
dnum = -dnum
422422

@@ -439,7 +439,7 @@ def __call__(self, num):
439439

440440
mant = sign*dnum/(10**pow10)
441441

442-
if self.precision is None:
442+
if self.precision is None: # pragma: no cover
443443
format_str = u"%g%s"
444444
elif self.precision == 0:
445445
format_str = u"%i%s"

pandas/core/groupby.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,13 @@ def _python_apply_general(self, func, *args, **kwargs):
412412
not_indexed_same = False
413413
for key, group in self:
414414
group.name = key
415+
416+
# capture the axes before calling func, since func may modify group in place
417+
group_axes = _get_axes(group)
418+
415419
res = func(group, *args, **kwargs)
416-
if not _is_indexed_like(res, group):
420+
421+
if not _is_indexed_like(res, group_axes):
417422
not_indexed_same = True
418423

419424
result_keys.append(key)
@@ -460,18 +465,19 @@ def groupby(obj, by, **kwds):
460465
return klass(obj, by, **kwds)
461466
groupby.__doc__ = GroupBy.__doc__
462467

463-
def _is_indexed_like(obj, other):
468+
def _get_axes(group):
469+
if isinstance(group, Series):
470+
return [group.index]
471+
else:
472+
return group.axes
473+
474+
def _is_indexed_like(obj, axes):
464475
if isinstance(obj, Series):
465-
if not isinstance(other, Series):
476+
if len(axes) > 1:
466477
return False
467-
return obj.index.equals(other.index)
478+
return obj.index.equals(axes[0])
468479
elif isinstance(obj, DataFrame):
469-
if isinstance(other, Series):
470-
return obj.index.equals(other.index)
471-
472-
# deal with this when a case arises
473-
assert(isinstance(other, DataFrame))
474-
return obj._indexed_same(other)
480+
return obj.index.equals(axes[0])
475481

476482
return False
477483

@@ -1093,11 +1099,7 @@ def _concat_frames(frames, index, columns=None, axis=0):
10931099
return result.reindex(index=index, columns=columns)
10941100

10951101
def _concat_indexes(indexes):
1096-
if len(indexes) == 1:
1097-
new_index = indexes[0]
1098-
else:
1099-
new_index = indexes[0].append(indexes[1:])
1100-
return new_index
1102+
return indexes[0].append(indexes[1:])
11011103

11021104
def _concat_frames_hierarchical(frames, keys, groupings, axis=0):
11031105
if axis == 0:
@@ -1135,8 +1137,14 @@ def _make_concat_multiindex(indexes, keys, groupings):
11351137
to_concat.append(np.repeat(k, len(index)))
11361138
label_list.append(np.concatenate(to_concat))
11371139

1138-
# these go in the last level
1139-
label_list.append(np.concatenate(indexes))
1140+
concat_index = _concat_indexes(indexes)
1141+
1142+
# the level values of the concatenated index go at the end
1143+
if isinstance(concat_index, MultiIndex):
1144+
for level in range(concat_index.nlevels):
1145+
label_list.append(concat_index.get_level_values(level))
1146+
else:
1147+
label_list.append(concat_index.values)
11401148

11411149
return MultiIndex.from_arrays(label_list)
11421150

pandas/tests/test_frame.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from numpy.random import randn
1212
import numpy as np
1313

14+
import pandas.core.common as common
1415
import pandas.core.datetools as datetools
1516
from pandas.core.index import NULL_INDEX
1617
from pandas.core.api import (DataFrame, Index, Series, notnull, isnull,
@@ -1245,10 +1246,26 @@ def test_repr(self):
12451246
index=np.arange(50))
12461247
foo = repr(unsortable)
12471248

1248-
import pandas.core.common as common
12491249
common.set_printoptions(precision=3, column_space=10)
12501250
repr(self.frame)
12511251

1252+
def test_eng_float_formatter(self):
1253+
self.frame.ix[5] = 0
1254+
1255+
common.set_eng_float_format()
1256+
1257+
repr(self.frame)
1258+
1259+
common.set_eng_float_format(use_eng_prefix=True)
1260+
1261+
repr(self.frame)
1262+
1263+
common.set_eng_float_format(precision=0)
1264+
1265+
repr(self.frame)
1266+
1267+
common.set_printoptions(precision=4)
1268+
12521269
def test_repr_tuples(self):
12531270
buf = StringIO()
12541271

pandas/tests/test_groupby.py

+25
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,31 @@ def f(group):
913913

914914
assert_frame_equal(result, expected)
915915

916+
def test_apply_corner(self):
917+
result = self.tsframe.groupby(lambda x: x.year).apply(lambda x: x * 2)
918+
expected = self.tsframe * 2
919+
assert_frame_equal(result, expected)
920+
921+
def test_transform_mixed_type(self):
922+
index = MultiIndex.from_arrays([[0, 0, 0, 1, 1, 1],
923+
[1, 2, 3, 1, 2, 3]])
924+
df = DataFrame({'d' : [1.,1.,1.,2.,2.,2.],
925+
'c' : np.tile(['a','b','c'], 2),
926+
'v' : np.arange(1., 7.)}, index=index)
927+
928+
def f(group):
929+
group['g'] = group['d'] * 2
930+
return group[:1]
931+
932+
grouped = df.groupby('c')
933+
result = grouped.apply(f)
934+
935+
self.assert_(result['d'].dtype == np.float64)
936+
937+
for key, group in grouped:
938+
res = f(group)
939+
assert_frame_equal(res, result.ix[key])
940+
916941
class TestPanelGroupBy(unittest.TestCase):
917942

918943
def setUp(self):

pandas/tests/test_index.py

+3
Original file line numberDiff line numberDiff line change
@@ -944,16 +944,19 @@ def test_diff(self):
944944
result = self.index - self.index
945945
expected = self.index[:0]
946946
self.assert_(result.equals(expected))
947+
self.assertEqual(result.names, self.index.names)
947948

948949
# empty difference: superset
949950
result = self.index[-3:] - self.index
950951
expected = self.index[:0]
951952
self.assert_(result.equals(expected))
953+
self.assertEqual(result.names, self.index.names)
952954

953955
# empty difference: degenerate
954956
result = self.index[:0] - self.index
955957
expected = self.index[:0]
956958
self.assert_(result.equals(expected))
959+
self.assertEqual(result.names, self.index.names)
957960

958961
# names not the same
959962
chunklet = self.index[-3:]

0 commit comments

Comments
 (0)