Skip to content

Commit ce86c21

Browse files
authored
BUG: preserve categorical & sparse types when grouping / pivot (#27071)
1 parent de0867f commit ce86c21

File tree

15 files changed

+205
-71
lines changed

15 files changed

+205
-71
lines changed

doc/source/whatsnew/v0.25.0.rst

+30-1
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,35 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
322322
s
323323
s.str.startswith(b'a')
324324
325+
.. _whatsnew_0250.api_breaking.groupby_categorical:
326+
327+
Categorical dtypes are preserved during groupby
328+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
329+
330+
Previously, columns that were categorical, but not the groupby key(s) would be converted to ``object`` dtype during groupby operations. Pandas now will preserve these dtypes. (:issue:`18502`)
331+
332+
.. ipython:: python
333+
334+
df = pd.DataFrame(
335+
{'payload': [-1, -2, -1, -2],
336+
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
337+
df
338+
df.dtypes
339+
340+
*Previous Behavior*:
341+
342+
.. code-block:: python
343+
344+
In [5]: df.groupby('payload').first().col.dtype
345+
Out[5]: dtype('O')
346+
347+
*New Behavior*:
348+
349+
.. ipython:: python
350+
351+
df.groupby('payload').first().col.dtype
352+
353+
325354
.. _whatsnew_0250.api_breaking.incompatible_index_unions:
326355

327356
Incompatible Index type unions
@@ -809,7 +838,7 @@ ExtensionArray
809838

810839
- Bug in :func:`factorize` when passing an ``ExtensionArray`` with a custom ``na_sentinel`` (:issue:`25696`).
811840
- :meth:`Series.count` miscounts NA values in ExtensionArrays (:issue:`26835`)
812-
- Keyword argument ``deep`` has been removed from :method:`ExtensionArray.copy` (:issue:`27083`)
841+
- Keyword argument ``deep`` has been removed from :meth:`ExtensionArray.copy` (:issue:`27083`)
813842

814843
Other
815844
^^^^^

pandas/core/groupby/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,
158158

159159
obj = self.obj[data.items[locs]]
160160
s = groupby(obj, self.grouper)
161-
result = s.aggregate(lambda x: alt(x, axis=self.axis))
161+
try:
162+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
163+
except TypeError:
164+
# we may have an exception in trying to aggregate
165+
# continue and exclude the block
166+
pass
162167

163168
finally:
164169

170+
dtype = block.values.dtype
171+
165172
# see if we can cast the block back to the original dtype
166-
result = block._try_coerce_and_cast_result(result)
173+
result = block._try_coerce_and_cast_result(result, dtype=dtype)
167174
newb = block.make_block(result)
168175

169176
new_items.append(locs)

pandas/core/groupby/groupby.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
786786
elif is_extension_array_dtype(dtype):
787787
# The function can return something of any type, so check
788788
# if the type is compatible with the calling EA.
789+
790+
# return the same type (Series) as our caller
789791
try:
790792
result = obj._values._from_sequence(result, dtype=dtype)
791793
except Exception:
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
11571159
"""
11581160
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
11591161
try:
1160-
return self._cython_agg_general('mean', **kwargs)
1162+
return self._cython_agg_general(
1163+
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
11611164
except GroupByError:
11621165
raise
11631166
except Exception: # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
11791182
Median of values within each group.
11801183
"""
11811184
try:
1182-
return self._cython_agg_general('median', **kwargs)
1185+
return self._cython_agg_general(
1186+
'median',
1187+
alt=lambda x,
1188+
axis: Series(x).median(axis=axis, **kwargs),
1189+
**kwargs)
11831190
except GroupByError:
11841191
raise
11851192
except Exception: # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
12351242
nv.validate_groupby_func('var', args, kwargs)
12361243
if ddof == 1:
12371244
try:
1238-
return self._cython_agg_general('var', **kwargs)
1245+
return self._cython_agg_general(
1246+
'var',
1247+
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
1248+
**kwargs)
12391249
except Exception:
12401250
f = lambda x: x.var(ddof=ddof, **kwargs)
12411251
with _group_selection_context(self):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
12631273
Series or DataFrame
12641274
Standard error of the mean of values within each group.
12651275
"""
1266-
12671276
return self.std(ddof=ddof) / np.sqrt(self.count())
12681277

12691278
@Substitution(name='groupby')
@@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
12901299
"""
12911300

12921301
def groupby_function(name, alias, npfunc,
1293-
numeric_only=True, _convert=False,
1302+
numeric_only=True,
12941303
min_count=-1):
12951304

12961305
_local_template = """
@@ -1312,17 +1321,30 @@ def f(self, **kwargs):
13121321
kwargs['min_count'] = min_count
13131322

13141323
self._set_group_selection()
1324+
1325+
# try a cython aggregation if we can
13151326
try:
13161327
return self._cython_agg_general(
13171328
alias, alt=npfunc, **kwargs)
13181329
except AssertionError as e:
13191330
raise SpecificationError(str(e))
13201331
except Exception:
1321-
result = self.aggregate(
1322-
lambda x: npfunc(x, axis=self.axis))
1323-
if _convert:
1324-
result = result._convert(datetime=True)
1325-
return result
1332+
pass
1333+
1334+
# apply a non-cython aggregation
1335+
result = self.aggregate(
1336+
lambda x: npfunc(x, axis=self.axis))
1337+
1338+
# coerce the resulting columns if we can
1339+
if isinstance(result, DataFrame):
1340+
for col in result.columns:
1341+
result[col] = self._try_cast(
1342+
result[col], self.obj[col])
1343+
else:
1344+
result = self._try_cast(
1345+
result, self.obj)
1346+
1347+
return result
13261348

13271349
set_function_name(f, name, cls)
13281350

pandas/core/groupby/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
2121
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
22-
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
22+
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
2323
is_timedelta64_dtype, needs_i8_conversion)
2424
from pandas.core.dtypes.missing import _maybe_fill, isna
2525

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,
451451

452452
# categoricals are only 1d, so we
453453
# are not setup for dim transforming
454-
if is_categorical_dtype(values):
454+
if is_categorical_dtype(values) or is_sparse(values):
455455
raise NotImplementedError(
456-
"categoricals are not support in cython ops ATM")
456+
"{} are not support in cython ops".format(values.dtype))
457457
elif is_datetime64_any_dtype(values):
458458
if how in ['add', 'prod', 'cumsum', 'cumprod']:
459459
raise NotImplementedError(

pandas/core/internals/blocks.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
594594
values = self.get_values(dtype=dtype)
595595

596596
# _astype_nansafe works fine with 1-d only
597-
values = astype_nansafe(values.ravel(), dtype, copy=True)
597+
values = astype_nansafe(
598+
values.ravel(), dtype, copy=True, **kwargs)
598599

599600
# TODO(extension)
600601
# should we make this attribute?
@@ -1746,6 +1747,27 @@ def _slice(self, slicer):
17461747

17471748
return self.values[slicer]
17481749

1750+
def _try_cast_result(self, result, dtype=None):
1751+
"""
1752+
if we have an operation that operates on for example floats
1753+
we want to try to cast back to our EA here if possible
1754+
1755+
result could be a 2-D numpy array, e.g. the result of
1756+
a numeric operation; but it must be shape (1, X) because
1757+
we by-definition operate on the ExtensionBlocks one-by-one
1758+
1759+
result could also be an EA Array itself, in which case it
1760+
is already a 1-D array
1761+
"""
1762+
try:
1763+
1764+
result = self._holder._from_sequence(
1765+
result.ravel(), dtype=dtype)
1766+
except Exception:
1767+
pass
1768+
1769+
return result
1770+
17491771
def formatting_values(self):
17501772
# Deprecating the ability to override _formatting_values.
17511773
# Do the warning here, it's only user in pandas, since we

pandas/core/internals/construction.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
687687
data = np.array(data, dtype=dtype, copy=False)
688688
subarr = np.array(data, dtype=object, copy=copy)
689689

690-
if is_object_dtype(subarr.dtype) and dtype != 'object':
690+
if (not (is_extension_array_dtype(subarr.dtype) or
691+
is_extension_array_dtype(dtype)) and
692+
is_object_dtype(subarr.dtype) and
693+
not is_object_dtype(dtype)):
691694
inferred = lib.infer_dtype(subarr, skipna=False)
692695
if inferred == 'period':
693696
try:

pandas/core/nanops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ def _f(*args, **kwargs):
7272

7373
class bottleneck_switch:
7474

75-
def __init__(self, **kwargs):
75+
def __init__(self, name=None, **kwargs):
76+
self.name = name
7677
self.kwargs = kwargs
7778

7879
def __call__(self, alt):
79-
bn_name = alt.__name__
80+
bn_name = self.name or alt.__name__
8081

8182
try:
8283
bn_func = getattr(bn, bn_name)
@@ -804,7 +805,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
804805

805806

806807
def _nanminmax(meth, fill_value_typ):
807-
@bottleneck_switch()
808+
809+
@bottleneck_switch(name='nan' + meth)
808810
def reduction(values, axis=None, skipna=True, mask=None):
809811

810812
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -824,7 +826,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
824826
result = _wrap_results(result, dtype, fill_value)
825827
return _maybe_null_out(result, axis, mask, values.shape)
826828

827-
reduction.__name__ = 'nan' + meth
828829
return reduction
829830

830831

pandas/tests/extension/base/groupby.py

+12
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ def test_groupby_extension_apply(
6464
df.groupby("A").apply(groupby_apply_op)
6565
df.groupby("A").B.apply(groupby_apply_op)
6666

67+
def test_groupby_apply_identity(self, data_for_grouping):
68+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
69+
"B": data_for_grouping})
70+
result = df.groupby('A').B.apply(lambda x: x.array)
71+
expected = pd.Series([df.B.iloc[[0, 1, 6]].array,
72+
df.B.iloc[[2, 3]].array,
73+
df.B.iloc[[4, 5]].array,
74+
df.B.iloc[[7]].array],
75+
index=pd.Index([1, 2, 3, 4], name='A'),
76+
name='B')
77+
self.assert_series_equal(result, expected)
78+
6779
def test_in_numeric_groupby(self, data_for_grouping):
6880
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
6981
"B": data_for_grouping,

pandas/tests/extension/decimal/test_decimal.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):
192192

193193

194194
class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
195-
pass
195+
196+
@pytest.mark.xfail(
197+
reason="needs to correctly define __eq__ to handle nans, xref #27081.")
198+
def test_groupby_apply_identity(self, data_for_grouping):
199+
super().test_groupby_apply_identity(data_for_grouping)
196200

197201

198202
class TestSetitem(BaseDecimal, base.BaseSetitemTests):

pandas/tests/groupby/test_categorical.py

+21
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,27 @@ def test_preserve_categorical_dtype():
697697
tm.assert_frame_equal(result2, expected)
698698

699699

700+
@pytest.mark.parametrize(
701+
'func, values',
702+
[('first', ['second', 'first']),
703+
('last', ['fourth', 'third']),
704+
('min', ['fourth', 'first']),
705+
('max', ['second', 'third'])])
706+
def test_preserve_on_ordered_ops(func, values):
707+
# gh-18502
708+
# preserve the categoricals on ops
709+
c = pd.Categorical(['first', 'second', 'third', 'fourth'], ordered=True)
710+
df = pd.DataFrame(
711+
{'payload': [-1, -2, -1, -2],
712+
'col': c})
713+
g = df.groupby('payload')
714+
result = getattr(g, func)()
715+
expected = pd.DataFrame(
716+
{'payload': [-2, -1],
717+
'col': pd.Series(values, dtype=c.dtype)}).set_index('payload')
718+
tm.assert_frame_equal(result, expected)
719+
720+
700721
def test_categorical_no_compress():
701722
data = Series(np.random.randn(9))
702723

0 commit comments

Comments
 (0)