Skip to content

Commit 4ed2e42

Browse files
committed
fixed DataFrame.toCSV bug, test coverage
1 parent 8f0a6c5 commit 4ed2e42

File tree

5 files changed: +56 additions, -24 deletions (lines changed)

pandas/core/frame.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -738,13 +738,18 @@ def toCSV(self, path, nanRep='', cols=None, header=True,
738738
for idx in self.index:
739739
if index:
740740
f.write(str(idx))
741-
for col in cols:
741+
for i, col in enumerate(cols):
742742
val = series[col].get(idx)
743743
if isnull(val):
744744
val = nanRep
745745
else:
746746
val = str(val)
747-
f.write(',%s' % val)
747+
748+
if i > 0 or index:
749+
f.write(',%s' % val)
750+
else:
751+
f.write('%s' % val)
752+
748753
f.write('\n')
749754

750755
f.close()

pandas/core/groupby.py

+14 -20
Original file line number | Diff line number | Diff line change
@@ -337,44 +337,38 @@ def transform(self, func):
337337
# DataMatrix objects?
338338
result_values = np.empty_like(self.obj.values)
339339

340-
if self.axis == 1:
341-
result_values = result_values.T
340+
if self.axis == 0:
341+
trans = lambda x: x
342+
elif self.axis == 1:
343+
trans = lambda x: x.T
344+
345+
result_values = trans(result_values)
342346

343-
# result = {}
344347
for val, group in self.groups.iteritems():
345-
if not isinstance(group, list):
348+
if not isinstance(group, list): # pragma: no cover
346349
group = list(group)
347350

348-
subframe = self.obj.reindex(group)
349-
subframe.groupName = val
350-
351351
if self.axis == 0:
352+
subframe = self.obj.reindex(group)
352353
indexer, _ = common.get_indexer(self.obj.index,
353354
subframe.index, None)
354355
else:
356+
subframe = self.obj.reindex(columns=group)
355357
indexer, _ = common.get_indexer(self.obj.columns,
356358
subframe.columns, None)
359+
subframe.groupName = val
357360

358361
try:
359-
res = func(subframe)
360-
except Exception:
361362
res = subframe.apply(func, axis=self.axis)
363+
except Exception: # pragma: no cover
364+
res = func(subframe)
362365

363-
result_values[indexer] = res.values
364-
365-
# result[val] = res
366+
result_values[indexer] = trans(res.values)
366367

367-
if self.axis == 1:
368-
result_values = result_values.T
368+
result_values = trans(result_values)
369369

370370
return DataFrame(result_values, index=self.obj.index,
371371
columns=self.obj.columns)
372-
# allSeries = {}
373-
# for val, frame in result.iteritems():
374-
# allSeries.update(frame._series)
375-
376-
# return self._klass(data=allSeries).T
377-
378372

379373
class DataMatrixGroupBy(DataFrameGroupBy):
380374
_klass = DataMatrix

pandas/core/matrix.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -791,7 +791,7 @@ def info(self, buffer=sys.stdout):
791791
#-------------------------------------------------------------------------------
792792
# Public methods
793793

794-
def apply(self, func, axis=0):
794+
def apply(self, func, axis=0, broadcast=False):
795795
"""
796796
Applies func to columns (Series) of this DataMatrix and returns either
797797
a DataMatrix (if the function produces another series) or a Series
@@ -802,6 +802,9 @@ def apply(self, func, axis=0):
802802
----------
803803
func : function
804804
Function to apply to each column
805+
broadcast : bool, default False
806+
For aggregation functions, return object of same size with values
807+
propagated
805808
806809
Examples
807810
--------
@@ -819,7 +822,8 @@ def apply(self, func, axis=0):
819822
return DataMatrix(data=results, index=self.index,
820823
columns=self.columns, objects=self.objects)
821824
else:
822-
return DataFrame.apply(self, func, axis=axis)
825+
return DataFrame.apply(self, func, axis=axis,
826+
broadcast=broadcast)
823827

824828
def applymap(self, func):
825829
"""

pandas/core/tests/test_frame.py

+24
Original file line number | Diff line number | Diff line change
@@ -696,6 +696,12 @@ def test_toCSV_fromcsv(self):
696696
recons = self.klass.fromcsv(path, index_col=None)
697697
assert(len(recons.cols()) == len(self.tsframe.cols()) + 1)
698698

699+
700+
# no index
701+
self.tsframe.toCSV(path, index=False)
702+
recons = self.klass.fromcsv(path, index_col=None)
703+
assert_almost_equal(self.tsframe.values, recons.values)
704+
699705
os.remove(path)
700706

701707
def test_toDataMatrix(self):
@@ -1173,6 +1179,19 @@ def test_apply(self):
11731179
applied = self.empty.apply(np.mean)
11741180
self.assert_(not applied)
11751181

1182+
1183+
def test_apply_broadcast(self):
1184+
broadcasted = self.frame.apply(np.mean, broadcast=True)
1185+
agged = self.frame.apply(np.mean)
1186+
1187+
for col, ts in broadcasted.iteritems():
1188+
self.assert_((ts == agged[col]).all())
1189+
1190+
broadcasted = self.frame.apply(np.mean, axis=1, broadcast=True)
1191+
agged = self.frame.apply(np.mean, axis=1)
1192+
for idx in broadcasted.index:
1193+
self.assert_((broadcasted.xs(idx) == agged[idx]).all())
1194+
11761195
def test_tapply(self):
11771196
d = self.frame.index[0]
11781197
tapplied = self.frame.tapply(np.mean)
@@ -1234,6 +1253,11 @@ def test_groupby_columns(self):
12341253
self.assertEqual(len(aggregated), len(self.tsframe))
12351254
self.assertEqual(len(aggregated.cols()), 2)
12361255

1256+
# transform
1257+
tf = lambda x: x - x.mean()
1258+
groupedT = self.tsframe.T.groupby(mapping, axis=0)
1259+
assert_frame_equal(groupedT.transform(tf).T, grouped.transform(tf))
1260+
12371261
# iterate
12381262
for k, v in grouped:
12391263
self.assertEqual(len(v.cols()), 2)

pandas/core/tests/test_series.py

+5
Original file line number | Diff line number | Diff line change
@@ -825,6 +825,10 @@ def test_groupby(self):
825825
self.assertEqual(agged[1], 1)
826826

827827
assert_series_equal(agged, grouped.agg(np.mean)) # shorthand
828+
assert_series_equal(agged, grouped.mean())
829+
830+
assert_series_equal(grouped.agg(np.sum), grouped.sum())
831+
828832

829833
transformed = grouped.transform(lambda x: x * x.sum())
830834
self.assertEqual(transformed[7], 12)
@@ -867,5 +871,6 @@ def test_groupby_transform(self):
867871
for idx in group.index:
868872
self.assertEqual(transformed[idx], mean)
869873

874+
870875
if __name__ == '__main__':
871876
unittest.main()

0 commit comments

Comments
 (0)