Skip to content

Commit 78eb2b9

Browse files
committed
Merge pull request #3152 from jreback/GH2763
BUG: GH2763 fixed downcasting of groupby results on SeriesGroupBy
2 parents b41dc91 + 2eda888 commit 78eb2b9

File tree

4 files changed

+49
-43
lines changed

4 files changed

+49
-43
lines changed

RELEASE.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ pandas 0.11.0
198198
an irrecoverable state (GH3010_)
199199
- Bug in DataFrame update, combine_first where non-specified values could cause
200200
dtype changes (GH3016_, GH3041_)
201-
- Bug in groupby with first/last where dtypes could change (GH3041_)
201+
- Bug in groupby with first/last where dtypes could change (GH3041_, GH2763_)
202202
- Formatting of an index that has ``nan`` was inconsistent or wrong (would fill from
203203
other values), (GH2850_)
204204
- Unstack of a frame with no nans would always cause dtype upcasting (GH2929_)
@@ -251,6 +251,7 @@ pandas 0.11.0
251251
.. _GH2746: https://github.com/pydata/pandas/issues/2746
252252
.. _GH2747: https://github.com/pydata/pandas/issues/2747
253253
.. _GH2751: https://github.com/pydata/pandas/issues/2751
254+
.. _GH2763: https://github.com/pydata/pandas/issues/2763
254255
.. _GH2776: https://github.com/pydata/pandas/issues/2776
255256
.. _GH2778: https://github.com/pydata/pandas/issues/2778
256257
.. _GH2787: https://github.com/pydata/pandas/issues/2787

pandas/core/groupby.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.util.compat import OrderedDict
1414
import pandas.core.algorithms as algos
1515
import pandas.core.common as com
16+
from pandas.core.common import _possibly_downcast_to_dtype
1617

1718
import pandas.lib as lib
1819
import pandas.algos as _algos
@@ -440,14 +441,7 @@ def _try_cast(self, result, obj):
440441

441442
# need to respect a non-number here (e.g. Decimal)
442443
if len(result) and issubclass(type(result[0]),(np.number,float,int)):
443-
if issubclass(dtype.type, (np.integer, np.bool_)):
444-
445-
# castable back to an int/bool as we don't have nans
446-
if com.notnull(result).all():
447-
result = result.astype(dtype)
448-
else:
449-
450-
result = result.astype(dtype)
444+
result = _possibly_downcast_to_dtype(result, dtype)
451445

452446
elif issubclass(dtype.type, np.datetime64):
453447
if is_datetime64_dtype(obj.dtype):
@@ -468,7 +462,7 @@ def _cython_agg_general(self, how, numeric_only=True):
468462
result, names = self.grouper.aggregate(obj.values, how)
469463
except AssertionError as e:
470464
raise GroupByError(str(e))
471-
output[name] = result
465+
output[name] = self._try_cast(result, obj)
472466

473467
if len(output) == 0:
474468
raise DataError('No numeric types to aggregate')

pandas/tests/test_groupby.py

+42-31
Original file line numberDiff line numberDiff line change
@@ -91,48 +91,51 @@ def setUp(self):
9191
'F': np.random.randn(11)})
9292

9393
def test_basic(self):
94-
data = Series(np.arange(9) // 3, index=np.arange(9))
9594

96-
index = np.arange(9)
97-
np.random.shuffle(index)
98-
data = data.reindex(index)
95+
def checkit(dtype):
96+
data = Series(np.arange(9) // 3, index=np.arange(9), dtype=dtype)
9997

100-
grouped = data.groupby(lambda x: x // 3)
98+
index = np.arange(9)
99+
np.random.shuffle(index)
100+
data = data.reindex(index)
101101

102-
for k, v in grouped:
103-
self.assertEqual(len(v), 3)
102+
grouped = data.groupby(lambda x: x // 3)
104103

105-
agged = grouped.aggregate(np.mean)
106-
self.assertEqual(agged[1], 1)
104+
for k, v in grouped:
105+
self.assertEqual(len(v), 3)
107106

108-
assert_series_equal(agged, grouped.agg(np.mean)) # shorthand
109-
assert_series_equal(agged, grouped.mean())
107+
agged = grouped.aggregate(np.mean)
108+
self.assertEqual(agged[1], 1)
110109

111-
# Cython only returning floating point for now...
112-
assert_series_equal(grouped.agg(np.sum).astype(float),
113-
grouped.sum())
110+
assert_series_equal(agged, grouped.agg(np.mean)) # shorthand
111+
assert_series_equal(agged, grouped.mean())
112+
assert_series_equal(grouped.agg(np.sum),grouped.sum())
114113

115-
transformed = grouped.transform(lambda x: x * x.sum())
116-
self.assertEqual(transformed[7], 12)
114+
transformed = grouped.transform(lambda x: x * x.sum())
115+
self.assertEqual(transformed[7], 12)
117116

118-
value_grouped = data.groupby(data)
119-
assert_series_equal(value_grouped.aggregate(np.mean), agged)
117+
value_grouped = data.groupby(data)
118+
assert_series_equal(value_grouped.aggregate(np.mean), agged)
120119

121-
# complex agg
122-
agged = grouped.aggregate([np.mean, np.std])
123-
agged = grouped.aggregate({'one': np.mean,
124-
'two': np.std})
120+
# complex agg
121+
agged = grouped.aggregate([np.mean, np.std])
122+
agged = grouped.aggregate({'one': np.mean,
123+
'two': np.std})
124+
125+
group_constants = {
126+
0: 10,
127+
1: 20,
128+
2: 30
129+
}
130+
agged = grouped.agg(lambda x: group_constants[x.name] + x.mean())
131+
self.assertEqual(agged[1], 21)
125132

126-
group_constants = {
127-
0: 10,
128-
1: 20,
129-
2: 30
130-
}
131-
agged = grouped.agg(lambda x: group_constants[x.name] + x.mean())
132-
self.assertEqual(agged[1], 21)
133+
# corner cases
134+
self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2)
133135

134-
# corner cases
135-
self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2)
136+
137+
for dtype in ['int64','int32','float64','float32']:
138+
checkit(dtype)
136139

137140
def test_first_last_nth(self):
138141
# tests for first / last / nth
@@ -185,6 +188,14 @@ def test_first_last_nth_dtypes(self):
185188
expected.index = ['bar', 'foo']
186189
assert_frame_equal(nth, expected, check_names=False)
187190

191+
# GH 2763, first/last shifting dtypes
192+
idx = range(10)
193+
idx.append(9)
194+
s = Series(data=range(11), index=idx, name='IntCol')
195+
self.assert_(s.dtype == 'int64')
196+
f = s.groupby(level=0).first()
197+
self.assert_(f.dtype == 'int64')
198+
188199
def test_grouper_iter(self):
189200
self.assertEqual(sorted(self.df.groupby('A').grouper), ['bar', 'foo'])
190201

pandas/tseries/tests/test_resample.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def test_custom_grouper(self):
7171
idx = idx.append(dti[-1:])
7272
expect = Series(arr, index=idx)
7373

74-
# cython returns float for now
74+
# GH2763 - return input dtype if we can
7575
result = g.agg(np.sum)
76-
assert_series_equal(result, expect.astype(float))
76+
assert_series_equal(result, expect)
7777

7878
data = np.random.rand(len(dti), 10)
7979
df = DataFrame(data, index=dti)

0 commit comments

Comments
 (0)