Skip to content

Commit 08a5523

Browse files
committed
ENH: test coverage, made pivot_table work with no rows passed and margins=True
1 parent 3fae25e commit 08a5523

File tree

7 files changed

+99
-77
lines changed

doc/source/io.rst

+7 -6
Original file line number | Diff line number | Diff line change
@@ -375,15 +375,16 @@ In a current or later Python session, you can retrieve stored objects:
375375
376376
store['df']
377377
378-
Storing in Table format
379-
~~~~~~~~~~~~~~~~~~~~~~~
380-
381-
Querying objects stored in Table format
382-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
383-
384378
.. ipython:: python
385379
:suppress:
386380
387381
store.close()
388382
import os
389383
os.remove('store.h5')
384+
385+
386+
.. Storing in Table format
387+
.. ~~~~~~~~~~~~~~~~~~~~~~~
388+
389+
.. Querying objects stored in Table format
390+
.. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

pandas/core/frame.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -4000,7 +4000,8 @@ def _lexsort_indexer(keys):
40004000
shape.append(len(rizer.uniques))
40014001

40024002
group_index = get_group_index(labels, shape)
4003-
comp_ids, _, max_group = _compress_group_index(group_index)
4003+
comp_ids, obs_ids = _compress_group_index(group_index)
4004+
max_group = len(obs_ids)
40044005
indexer, _ = lib.groupsort_indexer(comp_ids.astype('i4'), max_group)
40054006
return indexer
40064007

pandas/core/groupby.py

+36 -45
Original file line number | Diff line number | Diff line change
@@ -223,14 +223,13 @@ def __iter__(self):
223223
def _multi_iter(self):
224224
data = self.obj
225225
group_index = self._group_index
226-
comp_ids, _, ngroups = _compress_group_index(group_index)
226+
comp_ids, obs_ids = _compress_group_index(group_index)
227+
ngroups = len(obs_ids)
227228
label_list = [ping.labels for ping in self.groupings]
228229
level_list = [ping.group_index for ping in self.groupings]
229230
mapper = _KeyMapper(comp_ids, ngroups, label_list, level_list)
230231

231232
for label, group in self._generate_groups(data, comp_ids, ngroups):
232-
if group is None:
233-
continue
234233
key = mapper.get_key(label)
235234
yield key, group
236235

@@ -335,7 +334,8 @@ def _cython_agg_general(self, how):
335334
# aggregate all the columns at once?)
336335

337336
group_index = self._group_index
338-
comp_ids, obs_group_ids, max_group = _compress_group_index(group_index)
337+
comp_ids, obs_group_ids = _compress_group_index(group_index)
338+
max_group = len(obs_group_ids)
339339

340340
output = {}
341341
for name, obj in self._iterate_slices():
@@ -355,6 +355,32 @@ def _cython_agg_general(self, how):
355355

356356
return self._wrap_aggregated_output(output, mask, obs_group_ids)
357357

358+
def _python_agg_general(self, func, *args, **kwargs):
359+
agg_func = lambda x: func(x, *args, **kwargs)
360+
361+
group_index = self._group_index
362+
comp_ids, obs_group_ids = _compress_group_index(group_index)
363+
max_group = len(obs_group_ids)
364+
365+
# iterate through "columns" ex exclusions to populate output dict
366+
output = {}
367+
for name, obj in self._iterate_slices():
368+
try:
369+
result, counts = self._aggregate_series(obj, agg_func,
370+
comp_ids, max_group)
371+
output[name] = result
372+
except TypeError:
373+
continue
374+
375+
if len(output) == 0:
376+
return self._python_apply_general(func, *args, **kwargs)
377+
378+
mask = counts.ravel() > 0
379+
for name, result in output.iteritems():
380+
output[name] = result[mask]
381+
382+
return self._wrap_aggregated_output(output, mask, obs_group_ids)
383+
358384
@property
359385
def _group_index(self):
360386
result = get_group_index([ping.labels for ping in self.groupings],
@@ -380,31 +406,6 @@ def _get_group_levels(self, mask, obs_ids):
380406

381407
return name_list
382408

383-
def _python_agg_general(self, func, *args, **kwargs):
384-
agg_func = lambda x: func(x, *args, **kwargs)
385-
386-
group_index = self._group_index
387-
comp_ids, obs_group_ids, max_group = _compress_group_index(group_index)
388-
389-
# iterate through "columns" ex exclusions to populate output dict
390-
output = {}
391-
for name, obj in self._iterate_slices():
392-
try:
393-
result, counts = self._aggregate_series(obj, agg_func,
394-
comp_ids, max_group)
395-
output[name] = result
396-
except TypeError:
397-
continue
398-
399-
if len(output) == 0:
400-
return self._python_apply_general(func, *args, **kwargs)
401-
402-
mask = counts.ravel() > 0
403-
for name, result in output.iteritems():
404-
output[name] = result[mask]
405-
406-
return self._wrap_aggregated_output(output, mask, obs_group_ids)
407-
408409
def _aggregate_series(self, obj, func, group_index, ngroups):
409410
try:
410411
return self._aggregate_series_fast(obj, func, group_index, ngroups)
@@ -431,8 +432,6 @@ def _aggregate_series_pure_python(self, obj, func, group_index, ngroups):
431432
result = None
432433

433434
for label, group in self._generate_groups(obj, group_index, ngroups):
434-
if group is None:
435-
continue
436435
res = func(group)
437436
if result is None:
438437
try:
@@ -597,7 +596,6 @@ def __iter__(self):
597596
return iter(self.indices)
598597

599598
_labels = None
600-
_ids = None
601599
_counts = None
602600
_group_index = None
603601

@@ -615,13 +613,6 @@ def labels(self):
615613
self._make_labels()
616614
return self._labels
617615

618-
@property
619-
def ids(self):
620-
if self._ids is None:
621-
index = self.group_index
622-
self._ids = dict(zip(range(len(index)), index))
623-
return self._ids
624-
625616
@property
626617
def counts(self):
627618
if self._counts is None:
@@ -1297,10 +1288,11 @@ def _get_slice(slob):
12971288
ngroups)
12981289

12991290
for i, (start, end) in enumerate(zip(starts, ends)):
1300-
if start == end:
1301-
yield i, None
1302-
else:
1303-
yield i, _get_slice(slice(start, end))
1291+
# Since I'm now compressing the group ids, it's now not "possible" to
1292+
# produce empty slices because such groups would not be observed in the
1293+
# data
1294+
assert(start < end)
1295+
yield i, _get_slice(slice(start, end))
13041296

13051297
def get_group_index(label_list, shape):
13061298
if len(label_list) == 1:
@@ -1390,7 +1382,6 @@ def _compress_group_index(group_index, sort=True):
13901382

13911383
group_index = _ensure_int64(group_index)
13921384
comp_ids = table.get_labels_groupby(group_index, uniques)
1393-
max_group = len(uniques)
13941385

13951386
# these are the ones we observed
13961387
obs_group_ids = np.array(uniques, dtype='i8')
@@ -1406,7 +1397,7 @@ def _compress_group_index(group_index, sort=True):
14061397

14071398
obs_group_ids = obs_group_ids.take(sorter)
14081399

1409-
return comp_ids, obs_group_ids, max_group
1400+
return comp_ids, obs_group_ids
14101401

14111402
def _groupby_indices(values):
14121403
if values.dtype != np.object_:

pandas/tests/test_daterange.py

+5
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,11 @@ def test_getitem(self):
107107
# 32-bit vs. 64-bit platforms
108108
self.assertEquals(self.rng[4], self.rng[np.int_(4)])
109109

110+
def test_getitem_matplotlib_hackaround(self):
111+
values = self.rng[:, None]
112+
expected = self.rng.values[:, None]
113+
self.assert_(np.array_equal(values, expected))
114+
110115
def test_shift(self):
111116
shifted = self.rng.shift(5)
112117
self.assertEquals(shifted[0], self.rng[5])

pandas/tests/test_series.py

+4
Original file line number | Diff line number | Diff line change
@@ -1131,6 +1131,7 @@ def _check_op(arr, op):
11311131

11321132
def test_series_frame_radd_bug(self):
11331133
from pandas.util.testing import rands
1134+
import operator
11341135

11351136
# GH 353
11361137
vals = Series([rands(5) for _ in xrange(10)])
@@ -1143,6 +1144,9 @@ def test_series_frame_radd_bug(self):
11431144
expected = DataFrame({'vals' : vals.map(lambda x: 'foo_' + x)})
11441145
tm.assert_frame_equal(result, expected)
11451146

1147+
# really raise this time
1148+
self.assertRaises(TypeError, operator.add, datetime.now(), self.ts)
1149+
11461150
def test_operators_frame(self):
11471151
# rpow does not work with DataFrame
11481152
df = DataFrame({'A' : self.ts})

pandas/tools/pivot.py

+34 -17
Original file line number | Diff line number | Diff line change
@@ -108,23 +108,6 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
108108
DataFrame.pivot_table = pivot_table
109109

110110
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
111-
if len(cols) > 0:
112-
col_margin = data[rows + values].groupby(rows).agg(aggfunc)
113-
114-
# need to "interleave" the margins
115-
table_pieces = []
116-
margin_keys = []
117-
for key, piece in table.groupby(level=0, axis=1):
118-
all_key = (key, 'All') + ('',) * (len(cols) - 1)
119-
piece[all_key] = col_margin[key]
120-
table_pieces.append(piece)
121-
margin_keys.append(all_key)
122-
123-
result = concat(table_pieces, axis=1)
124-
else:
125-
result = table
126-
margin_keys = table.columns
127-
128111
grand_margin = {}
129112
for k, v in data[values].iteritems():
130113
try:
@@ -135,6 +118,40 @@ def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
135118
except TypeError:
136119
pass
137120

121+
if len(cols) > 0:
122+
# need to "interleave" the margins
123+
table_pieces = []
124+
margin_keys = []
125+
126+
127+
def _all_key(key):
128+
return (key, 'All') + ('',) * (len(cols) - 1)
129+
130+
if len(rows) > 0:
131+
margin = data[rows + values].groupby(rows).agg(aggfunc)
132+
cat_axis = 1
133+
for key, piece in table.groupby(level=0, axis=cat_axis):
134+
all_key = _all_key(key)
135+
piece[all_key] = margin[key]
136+
table_pieces.append(piece)
137+
margin_keys.append(all_key)
138+
else:
139+
margin = grand_margin
140+
cat_axis = 0
141+
for key, piece in table.groupby(level=0, axis=cat_axis):
142+
all_key = _all_key(key)
143+
table_pieces.append(piece)
144+
table_pieces.append(Series(margin[key], index=[all_key]))
145+
margin_keys.append(all_key)
146+
147+
result = concat(table_pieces, axis=cat_axis)
148+
149+
if len(rows) == 0:
150+
return result
151+
else:
152+
result = table
153+
margin_keys = table.columns
154+
138155
if len(cols) > 0:
139156
row_margin = data[cols + values].groupby(cols).agg(aggfunc)
140157
row_margin = row_margin.stack()

pandas/tools/tests/test_pivot.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -116,14 +116,17 @@ def _check_output(res, col, rows=['A', 'B'], cols=['C']):
116116
gmarg = table[valcol]['All', '']
117117
self.assertEqual(gmarg, self.data[valcol].mean())
118118

119-
# doesn't quite work yet
120-
121-
# # no rows
122-
# table = self.data.pivot_table(cols=['A', 'B'], margins=True,
123-
# aggfunc=np.mean)
124-
# for valcol in table.columns:
125-
# gmarg = table[valcol]['All', '']
126-
# self.assertEqual(gmarg, self.data[valcol].mean())
119+
# this is OK
120+
table = self.data.pivot_table(rows=['AA', 'BB'], margins=True,
121+
aggfunc='mean')
122+
123+
# no rows
124+
rtable = self.data.pivot_table(cols=['AA', 'BB'], margins=True,
125+
aggfunc=np.mean)
126+
self.assert_(isinstance(rtable, Series))
127+
for item in ['DD', 'EE', 'FF']:
128+
gmarg = table[item]['All', '']
129+
self.assertEqual(gmarg, self.data[item].mean())
127130

128131

129132
class TestCrosstab(unittest.TestCase):

0 commit comments

Comments
 (0)