Skip to content

Commit 2af5806

Browse files
committed
KLUDGE: check for iNaT in integer data prior to accumulate/transform in groupby
xref pandas-dev#15053
1 parent 0fe491d commit 2af5806

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

pandas/core/groupby.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
is_bool_dtype,
2525
is_scalar,
2626
is_list_like,
27+
needs_i8_conversion,
2728
_ensure_float64,
2829
_ensure_platform_int,
2930
_ensure_int64,
@@ -1844,15 +1845,21 @@ def _cython_operation(self, kind, values, how, axis):
18441845
"supported for the 'how' argument")
18451846
out_shape = (self.ngroups,) + values.shape[1:]
18461847

1848+
is_datetimelike = needs_i8_conversion(values.dtype)
18471849
is_numeric = is_numeric_dtype(values.dtype)
18481850

1849-
if is_datetime_or_timedelta_dtype(values.dtype):
1851+
if is_datetimelike:
18501852
values = values.view('int64')
18511853
is_numeric = True
18521854
elif is_bool_dtype(values.dtype):
18531855
values = _ensure_float64(values)
18541856
elif is_integer_dtype(values):
1855-
values = values.astype('int64', copy=False)
1857+
# we use iNaT for the missing value on ints
1858+
# so pre-convert to guard this condition
1859+
if (values == tslib.iNaT).any():
1860+
values = _ensure_float64(values)
1861+
else:
1862+
values = values.astype('int64', copy=False)
18561863
elif is_numeric and not is_complex_dtype(values):
18571864
values = _ensure_float64(values)
18581865
else:
@@ -1881,20 +1888,17 @@ def _cython_operation(self, kind, values, how, axis):
18811888
fill_value=np.nan)
18821889
counts = np.zeros(self.ngroups, dtype=np.int64)
18831890
result = self._aggregate(
1884-
result, counts, values, labels, func, is_numeric)
1891+
result, counts, values, labels, func,
1892+
is_numeric, is_datetimelike)
18851893
elif kind == 'transform':
18861894
result = _maybe_fill(np.empty_like(values, dtype=out_dtype),
18871895
fill_value=np.nan)
18881896

18891897
# temporary storage for running-total type transforms
18901898
accum = np.empty(out_shape, dtype=out_dtype)
18911899
result = self._transform(
1892-
result, accum, values, labels, func, is_numeric)
1893-
1894-
if is_integer_dtype(result):
1895-
if len(result[result == tslib.iNaT]) > 0:
1896-
result = result.astype('float64')
1897-
result[result == tslib.iNaT] = np.nan
1900+
result, accum, values, labels, func,
1901+
is_numeric, is_datetimelike)
18981902

18991903
if kind == 'aggregate' and \
19001904
self._filter_empty_groups and not counts.all():
@@ -1929,8 +1933,19 @@ def aggregate(self, values, how, axis=0):
19291933
def transform(self, values, how, axis=0):
19301934
return self._cython_operation('transform', values, how, axis)
19311935

1936+
def _maybe_mask_missing(self, result, is_datetimelike):
1937+
# we use iNaT as a marker for missing values
1938+
# but we *only* care for non-datetimelikes
1939+
if is_integer_dtype(result) and not is_datetimelike:
1940+
mask = result == tslib.iNaT
1941+
if mask.any():
1942+
result = result.astype('float64')
1943+
result[mask] = np.nan
1944+
return result
1945+
19321946
def _aggregate(self, result, counts, values, comp_ids, agg_func,
1933-
is_numeric):
1947+
is_numeric, is_datetimelike):
1948+
19341949
if values.ndim > 3:
19351950
# punting for now
19361951
raise NotImplementedError("number of dimensions is currently "
@@ -1943,11 +1958,12 @@ def _aggregate(self, result, counts, values, comp_ids, agg_func,
19431958
else:
19441959
agg_func(result, counts, values, comp_ids)
19451960

1946-
return result
1961+
return self._maybe_mask_missing(result, is_datetimelike)
19471962

19481963
def _transform(self, result, accum, values, comp_ids, transform_func,
1949-
is_numeric):
1964+
is_numeric, is_datetimelike):
19501965
comp_ids, _, ngroups = self.group_info
1966+
19511967
if values.ndim > 3:
19521968
# punting for now
19531969
raise NotImplementedError("number of dimensions is currently "
@@ -1961,7 +1977,7 @@ def _transform(self, result, accum, values, comp_ids, transform_func,
19611977
else:
19621978
transform_func(result, values, comp_ids, accum)
19631979

1964-
return result
1980+
return self._maybe_mask_missing(result, is_datetimelike)
19651981

19661982
def agg_series(self, obj, func):
19671983
try:

0 commit comments

Comments
 (0)