Skip to content

Commit cc43503

Browse files
stephenrauchjreback
authored andcommitted
BUG: GH15429 transform result of timedelta from datetime
The transform() operation needs to return a like-indexed. To facilitate this, transform starts with a copy of the original series. Then, after the computation for each group, sets the appropriate elements of the copied series equal to the result. At that point is does a type comparison, and discovers that the timedelta is not cast-able to a datetime.
1 parent fb7dc7d commit cc43503

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ Bug Fixes
626626

627627

628628
- Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
629+
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
629630

630631

631632
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)

pandas/core/groupby.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
_ensure_object,
3232
_ensure_categorical,
3333
_ensure_float)
34-
from pandas.types.cast import _possibly_downcast_to_dtype
34+
from pandas.types.cast import _possibly_downcast_to_dtype, _find_common_type
3535
from pandas.types.missing import isnull, notnull, _maybe_fill
3636

3737
from pandas.core.common import (_values_from_object, AbstractMethodError,
@@ -2906,8 +2906,15 @@ def transform(self, func, *args, **kwargs):
29062906
common_type = np.common_type(np.array(res), result)
29072907
if common_type != result.dtype:
29082908
result = result.astype(common_type)
2909-
except:
2910-
pass
2909+
except Exception as exc:
2910+
# date math can cause type of result to change
2911+
if i == 0 and (is_datetime64_dtype(result.dtype) or
2912+
is_timedelta64_dtype(result.dtype)):
2913+
try:
2914+
dtype = res.dtype
2915+
except Exception as exc:
2916+
dtype = type(res)
2917+
result = np.empty_like(result, dtype)
29112918

29122919
indexer = self._get_index(name)
29132920
result[indexer] = res

pandas/tests/groupby/test_transform.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas as pd
55
from pandas.util import testing as tm
6-
from pandas import Series, DataFrame, Timestamp, MultiIndex, concat
6+
from pandas import Series, DataFrame, Timestamp, MultiIndex, concat, date_range
77
from pandas.types.common import _ensure_platform_int
88
from .common import MixIn, assert_fp_equal
99

@@ -190,6 +190,43 @@ def test_transform_bug(self):
190190
expected = Series(np.arange(5, 0, step=-1), name='B')
191191
assert_series_equal(result, expected)
192192

193+
def test_transform_datetime_to_timedelta(self):
194+
# GH 15429
195+
# transforming a datetime to timedelta
196+
df = DataFrame(dict(A=Timestamp('20130101'), B=np.arange(5)))
197+
expected = pd.Series([
198+
Timestamp('20130101') - Timestamp('20130101')] * 5, name='A')
199+
200+
# this does date math without changing result type in transform
201+
base_time = df['A'][0]
202+
result = df.groupby('A')['A'].transform(
203+
lambda x: x.max() - x.min() + base_time) - base_time
204+
assert_series_equal(result, expected)
205+
206+
# this does date math and causes the transform to return timedelta
207+
result = df.groupby('A')['A'].transform(lambda x: x.max() - x.min())
208+
assert_series_equal(result, expected)
209+
210+
def test_transform_datetime_to_numeric(self):
211+
# GH 10972
212+
# convert dt to float
213+
df = DataFrame({
214+
'a': 1, 'b': date_range('2015-01-01', periods=2, freq='D')})
215+
result = df.groupby('a').b.transform(
216+
lambda x: x.dt.dayofweek - x.dt.dayofweek.mean())
217+
218+
expected = Series([-0.5, 0.5], name='b')
219+
assert_series_equal(result, expected)
220+
221+
# convert dt to int
222+
df = DataFrame({
223+
'a': 1, 'b': date_range('2015-01-01', periods=2, freq='D')})
224+
result = df.groupby('a').b.transform(
225+
lambda x: x.dt.dayofweek - x.dt.dayofweek.min())
226+
227+
expected = Series([0, 1], name='b')
228+
assert_series_equal(result, expected)
229+
193230
def test_transform_multiple(self):
194231
grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])
195232

0 commit comments

Comments
 (0)