Skip to content

Commit 61fa8be

Browse files
committed
BUG: fix groupby.aggregate resulting dtype coercion, xref pandas-dev#11444, pandas-dev#13046
make sure .size includes the name of the grouped
1 parent 251826f commit 61fa8be

File tree

5 files changed

+72
-13
lines changed

5 files changed

+72
-13
lines changed

doc/source/whatsnew/v0.20.0.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -632,12 +632,12 @@ Bug Fixes
632632

633633
- Bug in ``DataFrame.to_stata()`` and ``StataWriter`` which produced incorrectly formatted files for some locales (:issue:`13856`)
634634
- Bug in ``pd.concat()`` in which concatenating with an empty dataframe with ``join='inner'`` was being improperly handled (:issue:`15328`)
635-
- Bug in ``groupby.agg()`` incorrectly localizing timezone on ``datetime`` (:issue:`15426`, :issue:`10668`)
635+
- Bug in ``groupby.agg()`` incorrectly localizing timezone on ``datetime`` (:issue:`15426`, :issue:`10668`, :issue:`13046`)
636636

637637

638638

639639
- Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
640-
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
640+
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`, :issue:`11444`)
641641

642642
- Bug in ``DataFrame.hist`` where ``plt.tight_layout`` caused an ``AttributeError`` (use ``matplotlib >= 0.2.0``) (:issue:`9351`)
643643
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)

pandas/core/groupby.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,23 @@ def _index_with_as_index(self, b):
767767
new.names = gp.names + original.names
768768
return new
769769

770-
def _try_cast(self, result, obj):
770+
def _try_cast(self, result, obj, numeric_only=False):
771771
"""
772772
try to cast the result to our obj original type,
773773
we may have round-tripped through object in the meantime
774774
775+
if numeric_only is True, then only try to cast numerics
776+
and not datetimelikes
777+
775778
"""
776779
if obj.ndim > 1:
777780
dtype = obj.values.dtype
778781
else:
779782
dtype = obj.dtype
780783

781784
if not is_scalar(result):
782-
result = _possibly_downcast_to_dtype(result, dtype)
785+
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
786+
result = _possibly_downcast_to_dtype(result, dtype)
783787

784788
return result
785789

@@ -830,7 +834,7 @@ def _python_agg_general(self, func, *args, **kwargs):
830834
for name, obj in self._iterate_slices():
831835
try:
832836
result, counts = self.grouper.agg_series(obj, f)
833-
output[name] = self._try_cast(result, obj)
837+
output[name] = self._try_cast(result, obj, numeric_only=True)
834838
except TypeError:
835839
continue
836840

@@ -1117,7 +1121,11 @@ def sem(self, ddof=1):
11171121
@Appender(_doc_template)
11181122
def size(self):
11191123
"""Compute group sizes"""
1120-
return self.grouper.size()
1124+
result = self.grouper.size()
1125+
1126+
if isinstance(self.obj, Series):
1127+
result.name = getattr(self, 'name', None)
1128+
return result
11211129

11221130
sum = _groupby_function('sum', 'add', np.sum)
11231131
prod = _groupby_function('prod', 'prod', np.prod)
@@ -1689,7 +1697,9 @@ def size(self):
16891697
ids, _, ngroup = self.group_info
16901698
ids = _ensure_platform_int(ids)
16911699
out = np.bincount(ids[ids != -1], minlength=ngroup or None)
1692-
return Series(out, index=self.result_index, dtype='int64')
1700+
return Series(out,
1701+
index=self.result_index,
1702+
dtype='int64')
16931703

16941704
@cache_readonly
16951705
def _max_groupsize(self):
@@ -2908,7 +2918,8 @@ def transform(self, func, *args, **kwargs):
29082918
result = concat(results).sort_index()
29092919

29102920
# we will only try to coerce the result type if
2911-
# we have a numeric dtype
2921+
# we have a numeric dtype, as these are *always* udfs
2922+
# the cython functions take a different path (and casting)
29122923
dtype = self._selected_obj.dtype
29132924
if is_numeric_dtype(dtype):
29142925
result = _possibly_downcast_to_dtype(result, dtype)

pandas/tests/groupby/test_aggregate.py

+23
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,29 @@ def test_agg_dict_parameter_cast_result_dtypes(self):
154154
assert_series_equal(grouped.time.last(), exp['time'])
155155
assert_series_equal(grouped.time.agg('last'), exp['time'])
156156

157+
# count
158+
exp = pd.Series([2, 2, 2, 2],
159+
index=Index(list('ABCD'), name='class'),
160+
name='time')
161+
assert_series_equal(grouped.time.agg(len), exp)
162+
assert_series_equal(grouped.time.size(), exp)
163+
164+
exp = pd.Series([0, 1, 1, 2],
165+
index=Index(list('ABCD'), name='class'),
166+
name='time')
167+
assert_series_equal(grouped.time.count(), exp)
168+
169+
def test_agg_cast_results_dtypes(self):
170+
# similar to GH12821
171+
# xref #11444
172+
u = [datetime(2015, x + 1, 1) for x in range(12)]
173+
v = list('aaabbbbbbccd')
174+
df = pd.DataFrame({'X': v, 'Y': u})
175+
176+
result = df.groupby('X')['Y'].agg(len)
177+
expected = df.groupby('X')['Y'].count()
178+
assert_series_equal(result, expected)
179+
157180
def test_agg_must_agg(self):
158181
grouped = self.df.groupby('A')['C']
159182
self.assertRaises(Exception, grouped.agg, lambda x: x.describe())

pandas/tests/groupby/test_transform.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pandas as pd
55
from pandas.util import testing as tm
66
from pandas import Series, DataFrame, Timestamp, MultiIndex, concat, date_range
7-
from pandas.types.common import _ensure_platform_int
7+
from pandas.types.common import _ensure_platform_int, is_timedelta64_dtype
8+
from pandas.compat import StringIO
89
from .common import MixIn, assert_fp_equal
910

1011
from pandas.util.testing import assert_frame_equal, assert_series_equal
@@ -227,6 +228,32 @@ def test_transform_datetime_to_numeric(self):
227228
expected = Series([0, 1], name='b')
228229
assert_series_equal(result, expected)
229230

231+
def test_transform_casting(self):
232+
# 13046
233+
data = """
234+
idx A ID3 DATETIME
235+
0 B-028 b76cd912ff "2014-10-08 13:43:27"
236+
1 B-054 4a57ed0b02 "2014-10-08 14:26:19"
237+
2 B-076 1a682034f8 "2014-10-08 14:29:01"
238+
3 B-023 b76cd912ff "2014-10-08 18:39:34"
239+
4 B-023 f88g8d7sds "2014-10-08 18:40:18"
240+
5 B-033 b76cd912ff "2014-10-08 18:44:30"
241+
6 B-032 b76cd912ff "2014-10-08 18:46:00"
242+
7 B-037 b76cd912ff "2014-10-08 18:52:15"
243+
8 B-046 db959faf02 "2014-10-08 18:59:59"
244+
9 B-053 b76cd912ff "2014-10-08 19:17:48"
245+
10 B-065 b76cd912ff "2014-10-08 19:21:38"
246+
"""
247+
df = pd.read_csv(StringIO(data), sep='\s+',
248+
index_col=[0], parse_dates=['DATETIME'])
249+
250+
result = df.groupby('ID3')['DATETIME'].transform(lambda x: x.diff())
251+
assert is_timedelta64_dtype(result.dtype)
252+
253+
result = df[['ID3', 'DATETIME']].groupby('ID3').transform(
254+
lambda x: x.diff())
255+
assert is_timedelta64_dtype(result.DATETIME.dtype)
256+
230257
def test_transform_multiple(self):
231258
grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])
232259

pandas/tests/tseries/test_resample.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -757,10 +757,8 @@ def test_resample_empty_series(self):
757757
freq in ['M', 'D']):
758758
# GH12871 - TODO: name should propagate, but currently
759759
# doesn't on lower / same frequency with PeriodIndex
760-
assert_series_equal(result, expected, check_dtype=False,
761-
check_names=False)
762-
# this assert will break when fixed
763-
self.assertTrue(result.name is None)
760+
assert_series_equal(result, expected, check_dtype=False)
761+
764762
else:
765763
assert_series_equal(result, expected, check_dtype=False)
766764

0 commit comments

Comments
 (0)