Skip to content

Commit cdd78db

Browse files
committed
BUG: preserve categorical & sparse types when grouping / pivot
closes pandas-dev#18502
1 parent de0867f commit cdd78db

File tree

12 files changed

+154
-62
lines changed

12 files changed

+154
-62
lines changed

doc/source/whatsnew/v0.25.0.rst

+29
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,35 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
322322
s
323323
s.str.startswith(b'a')
324324
325+
.. _whatsnew_0250.api_breaking.groupby_categorical:
326+
327+
Categorical dtypes are preserved during groupby
328+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
329+
330+
Previously, columns that were categorical, but not the groupby key(s) would be converted to ``object`` dtype during groupby operations. Pandas now will preserve these dtypes. (:issue:`18502`)
331+
332+
.. ipython:: python
333+
334+
df = pd.DataFrame(
335+
{'payload': [-1, -2, -1, -2],
336+
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
337+
df
338+
df.dtypes
339+
340+
*Previous Behavior*:
341+
342+
.. code-block:: python
343+
344+
In [5]: df.groupby('payload').first().col.dtype
345+
Out[5]: dtype('O')
346+
347+
*New Behavior*:
348+
349+
.. ipython:: python
350+
351+
df.groupby('payload').first().col.dtype
352+
353+
325354
.. _whatsnew_0250.api_breaking.incompatible_index_unions:
326355

327356
Incompatible Index type unions

pandas/core/groupby/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,
158158

159159
obj = self.obj[data.items[locs]]
160160
s = groupby(obj, self.grouper)
161-
result = s.aggregate(lambda x: alt(x, axis=self.axis))
161+
try:
162+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
163+
except Exception:
164+
# we may have an exception in trying to aggregate
165+
# continue and exclude the block
166+
pass
162167

163168
finally:
164169

170+
dtype = block.values.dtype
171+
165172
# see if we can cast the block back to the original dtype
166-
result = block._try_coerce_and_cast_result(result)
173+
result = block._try_coerce_and_cast_result(result, dtype=dtype)
167174
newb = block.make_block(result)
168175

169176
new_items.append(locs)

pandas/core/groupby/groupby.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
786786
elif is_extension_array_dtype(dtype):
787787
# The function can return something of any type, so check
788788
# if the type is compatible with the calling EA.
789+
790+
# return the same type (Series) as our caller
789791
try:
790792
result = obj._values._from_sequence(result, dtype=dtype)
791793
except Exception:
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
11571159
"""
11581160
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
11591161
try:
1160-
return self._cython_agg_general('mean', **kwargs)
1162+
return self._cython_agg_general(
1163+
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
11611164
except GroupByError:
11621165
raise
11631166
except Exception: # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
11791182
Median of values within each group.
11801183
"""
11811184
try:
1182-
return self._cython_agg_general('median', **kwargs)
1185+
return self._cython_agg_general(
1186+
'median',
1187+
alt=lambda x,
1188+
axis: Series(x).median(axis=axis, **kwargs),
1189+
**kwargs)
11831190
except GroupByError:
11841191
raise
11851192
except Exception: # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
12351242
nv.validate_groupby_func('var', args, kwargs)
12361243
if ddof == 1:
12371244
try:
1238-
return self._cython_agg_general('var', **kwargs)
1245+
return self._cython_agg_general(
1246+
'var',
1247+
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
1248+
**kwargs)
12391249
except Exception:
12401250
f = lambda x: x.var(ddof=ddof, **kwargs)
12411251
with _group_selection_context(self):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
12631273
Series or DataFrame
12641274
Standard error of the mean of values within each group.
12651275
"""
1266-
12671276
return self.std(ddof=ddof) / np.sqrt(self.count())
12681277

12691278
@Substitution(name='groupby')
@@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
12901299
"""
12911300

12921301
def groupby_function(name, alias, npfunc,
1293-
numeric_only=True, _convert=False,
1302+
numeric_only=True,
12941303
min_count=-1):
12951304

12961305
_local_template = """
@@ -1312,17 +1321,30 @@ def f(self, **kwargs):
13121321
kwargs['min_count'] = min_count
13131322

13141323
self._set_group_selection()
1324+
1325+
# try a cython aggregation if we can
13151326
try:
13161327
return self._cython_agg_general(
13171328
alias, alt=npfunc, **kwargs)
13181329
except AssertionError as e:
13191330
raise SpecificationError(str(e))
13201331
except Exception:
1321-
result = self.aggregate(
1322-
lambda x: npfunc(x, axis=self.axis))
1323-
if _convert:
1324-
result = result._convert(datetime=True)
1325-
return result
1332+
pass
1333+
1334+
# apply a non-cython aggregation
1335+
result = self.aggregate(
1336+
lambda x: npfunc(x, axis=self.axis))
1337+
1338+
# coerce the resulting columns if we can
1339+
if isinstance(result, DataFrame):
1340+
for col in result.columns:
1341+
result[col] = self._try_cast(
1342+
result[col], self.obj[col])
1343+
else:
1344+
result = self._try_cast(
1345+
result, self.obj)
1346+
1347+
return result
13261348

13271349
set_function_name(f, name, cls)
13281350

pandas/core/groupby/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
2121
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
22-
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
22+
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
2323
is_timedelta64_dtype, needs_i8_conversion)
2424
from pandas.core.dtypes.missing import _maybe_fill, isna
2525

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,
451451

452452
# categoricals are only 1d, so we
453453
# are not setup for dim transforming
454-
if is_categorical_dtype(values):
454+
if is_categorical_dtype(values) or is_sparse(values):
455455
raise NotImplementedError(
456-
"categoricals are not support in cython ops ATM")
456+
"{} are not support in cython ops".format(values.dtype))
457457
elif is_datetime64_any_dtype(values):
458458
if how in ['add', 'prod', 'cumsum', 'cumprod']:
459459
raise NotImplementedError(

pandas/core/indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pandas.core.dtypes.common import (
1212
ensure_platform_int, is_float, is_integer, is_integer_dtype, is_iterator,
13-
is_list_like, is_numeric_dtype, is_scalar, is_sequence, is_sparse)
13+
is_list_like, is_numeric_dtype, is_scalar, is_sequence)
1414
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
1515
from pandas.core.dtypes.missing import _infer_fill_value, isna
1616

pandas/core/internals/blocks.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
594594
values = self.get_values(dtype=dtype)
595595

596596
# _astype_nansafe works fine with 1-d only
597-
values = astype_nansafe(values.ravel(), dtype, copy=True)
597+
values = astype_nansafe(
598+
values.ravel(), dtype, copy=True, **kwargs)
598599

599600
# TODO(extension)
600601
# should we make this attribute?
@@ -1746,6 +1747,27 @@ def _slice(self, slicer):
17461747

17471748
return self.values[slicer]
17481749

1750+
def _try_cast_result(self, result, dtype=None):
1751+
"""
1752+
if we have an operation that operates on for example floats
1753+
we want to try to cast back to our EA here if possible
1754+
1755+
result could be a 2-D numpy array, e.g. the result of
1756+
a numeric operation; but it must be shape (1, X) because
1757+
we by-definition operate on the ExtensionBlocks one-by-one
1758+
1759+
result could also be an EA Array itself, in which case it
1760+
is already a 1-D array
1761+
"""
1762+
try:
1763+
1764+
result = self._holder._from_sequence(
1765+
np.asarray(result).ravel(), dtype=dtype)
1766+
except Exception:
1767+
pass
1768+
1769+
return result
1770+
17491771
def formatting_values(self):
17501772
# Deprecating the ability to override _formatting_values.
17511773
# Do the warning here, it's only user in pandas, since we

pandas/core/internals/construction.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
687687
data = np.array(data, dtype=dtype, copy=False)
688688
subarr = np.array(data, dtype=object, copy=copy)
689689

690-
if is_object_dtype(subarr.dtype) and dtype != 'object':
690+
if (not (is_extension_array_dtype(subarr.dtype) or
691+
is_extension_array_dtype(dtype)) and
692+
is_object_dtype(subarr.dtype) and
693+
not is_object_dtype(dtype)):
691694
inferred = lib.infer_dtype(subarr, skipna=False)
692695
if inferred == 'period':
693696
try:

pandas/core/nanops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ def _f(*args, **kwargs):
7272

7373
class bottleneck_switch:
7474

75-
def __init__(self, **kwargs):
75+
def __init__(self, name=None, **kwargs):
76+
self.name = name
7677
self.kwargs = kwargs
7778

7879
def __call__(self, alt):
79-
bn_name = alt.__name__
80+
bn_name = self.name or alt.__name__
8081

8182
try:
8283
bn_func = getattr(bn, bn_name)
@@ -804,7 +805,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
804805

805806

806807
def _nanminmax(meth, fill_value_typ):
807-
@bottleneck_switch()
808+
809+
@bottleneck_switch(name='nan' + meth)
808810
def reduction(values, axis=None, skipna=True, mask=None):
809811

810812
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -824,7 +826,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
824826
result = _wrap_results(result, dtype, fill_value)
825827
return _maybe_null_out(result, axis, mask, values.shape)
826828

827-
reduction.__name__ = 'nan' + meth
828829
return reduction
829830

830831

pandas/tests/groupby/test_function.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pandas import (
1313
DataFrame, Index, MultiIndex, Series, Timestamp, date_range, isna)
1414
import pandas.core.nanops as nanops
15-
from pandas.util import testing as tm
15+
from pandas.util import _test_decorators as td, testing as tm
1616

1717

1818
@pytest.mark.parametrize("agg_func", ['any', 'all'])
@@ -144,6 +144,7 @@ def test_arg_passthru():
144144
index=Index([1, 2], name='group'),
145145
columns=['int', 'float', 'category_int',
146146
'datetime', 'datetimetz', 'timedelta'])
147+
147148
for attr in ['mean', 'median']:
148149
f = getattr(df.groupby('group'), attr)
149150
result = f()
@@ -459,35 +460,33 @@ def test_groupby_cumprod():
459460
tm.assert_series_equal(actual, expected)
460461

461462

462-
def test_ops_general():
463-
ops = [('mean', np.mean),
464-
('median', np.median),
465-
('std', np.std),
466-
('var', np.var),
467-
('sum', np.sum),
468-
('prod', np.prod),
469-
('min', np.min),
470-
('max', np.max),
471-
('first', lambda x: x.iloc[0]),
472-
('last', lambda x: x.iloc[-1]),
473-
('count', np.size), ]
474-
try:
475-
from scipy.stats import sem
476-
except ImportError:
477-
pass
478-
else:
479-
ops.append(('sem', sem))
463+
def scipy_sem(*args, **kwargs):
464+
from scipy.stats import sem
465+
return sem(*args, ddof=1, **kwargs)
466+
467+
468+
@pytest.mark.parametrize(
469+
'op,targop',
470+
[('mean', np.mean),
471+
('median', np.median),
472+
('std', np.std),
473+
('var', np.var),
474+
('sum', np.sum),
475+
('prod', np.prod),
476+
('min', np.min),
477+
('max', np.max),
478+
('first', lambda x: x.iloc[0]),
479+
('last', lambda x: x.iloc[-1]),
480+
('count', np.size),
481+
pytest.param(
482+
'sem', scipy_sem, marks=td.skip_if_no_scipy)])
483+
def test_ops_general(op, targop):
480484
df = DataFrame(np.random.randn(1000))
481485
labels = np.random.randint(0, 50, size=1000).astype(float)
482486

483-
for op, targop in ops:
484-
result = getattr(df.groupby(labels), op)().astype(float)
485-
expected = df.groupby(labels).agg(targop)
486-
try:
487-
tm.assert_frame_equal(result, expected)
488-
except BaseException as exc:
489-
exc.args += ('operation: %s' % op, )
490-
raise
487+
result = getattr(df.groupby(labels), op)().astype(float)
488+
expected = df.groupby(labels).agg(targop)
489+
tm.assert_frame_equal(result, expected)
491490

492491

493492
def test_max_nan_bug():

pandas/tests/groupby/test_nth.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,21 @@ def test_first_last_tz(data, expected_first, expected_last):
282282
])
283283
def test_first_last_tz_multi_column(method, ts, alpha):
284284
# GH 21603
285+
category_string = pd.Series(list('abc')).astype(
286+
'category')
285287
df = pd.DataFrame({'group': [1, 1, 2],
286-
'category_string': pd.Series(list('abc')).astype(
287-
'category'),
288+
'category_string': category_string,
288289
'datetimetz': pd.date_range('20130101', periods=3,
289290
tz='US/Eastern')})
290291
result = getattr(df.groupby('group'), method)()
291-
expepcted = pd.DataFrame({'category_string': [alpha, 'c'],
292-
'datetimetz': [ts,
293-
Timestamp('2013-01-03',
294-
tz='US/Eastern')]},
295-
index=pd.Index([1, 2], name='group'))
296-
assert_frame_equal(result, expepcted)
292+
expected = pd.DataFrame(
293+
{'category_string': pd.Categorical(
294+
[alpha, 'c'], dtype=category_string.dtype),
295+
'datetimetz': [ts,
296+
Timestamp('2013-01-03',
297+
tz='US/Eastern')]},
298+
index=pd.Index([1, 2], name='group'))
299+
assert_frame_equal(result, expected)
297300

298301

299302
def test_nth_multi_index_as_expected():

pandas/tests/resample/test_datetime_index.py

+6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ def test_resample_integerarray():
112112
dtype="Int64")
113113
assert_series_equal(result, expected)
114114

115+
result = ts.resample('3T').mean()
116+
expected = Series([1, 4, 7],
117+
index=pd.date_range('1/1/2000', periods=3, freq='3T'),
118+
dtype='Int64')
119+
assert_series_equal(result, expected)
120+
115121

116122
def test_resample_basic_grouper(series):
117123
s = series

0 commit comments

Comments
 (0)