Skip to content

Commit cbeb91e

Browse files
committed
BUG: preserve categorical & sparse types when grouping / pivot
preserve dtypes when applying a ufunc to a sparse dtype closes pandas-dev#18502 closes pandas-dev#23743
1 parent 0f3e8e8 commit cbeb91e

File tree

17 files changed

+281
-81
lines changed

17 files changed

+281
-81
lines changed

doc/source/whatsnew/v0.25.0.rst

+58-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,64 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
268268
s
269269
s.str.startswith(b'a')
270270
271-
.. _whatsnew_0250.api_breaking.incompatible_index_unions
271+
.. _whatsnew_0250.api_breaking.ufuncs:
272+
273+
ufuncs on Extension Dtype
274+
^^^^^^^^^^^^^^^^^^^^^^^^^
275+
276+
Operations with ``numpy`` ufuncs on Extension Arrays, including Sparse Dtypes, will now coerce the
277+
resulting dtype to the same as the input dtype; previously this would coerce to a dense dtype. (:issue:`23743`)
278+
279+
.. ipython:: python
280+
281+
df = pd.DataFrame({'A': pd.Series([1, np.nan, 3], dtype=pd.SparseDtype('float64', np.nan))})
282+
df
283+
df.dtypes
284+
285+
*Previous Behavior*:
286+
287+
.. code-block:: python
288+
289+
In [3]: np.sqrt(df).dtypes
290+
Out[3]:
291+
A float64
292+
dtype: object
293+
294+
*New Behavior*:
295+
296+
.. ipython:: python
297+
298+
np.sqrt(df).dtypes
299+
300+
.. _whatsnew_0250.api_breaking.groupby_categorical:
301+
302+
Categorical dtypes are preserved during groupby
303+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
304+
305+
Previously, columns that were categorical, but not the groupby key(s), would be converted to ``object`` dtype during groupby operations. Pandas will now preserve these dtypes. (:issue:`18502`)
306+
307+
.. ipython:: python
308+
309+
df = pd.DataFrame({'payload': [-1,-2,-1,-2],
310+
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
311+
df
312+
df.dtypes
313+
314+
*Previous Behavior*:
315+
316+
.. code-block:: python
317+
318+
In [5]: df.groupby('payload').first().col.dtype
319+
Out[5]: dtype('O')
320+
321+
*New Behavior*:
322+
323+
.. ipython:: python
324+
325+
df.groupby('payload').first().col.dtype
326+
327+
328+
.. _whatsnew_0250.api_breaking.incompatible_index_unions:
272329

273330
Incompatible Index Type Unions
274331
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

pandas/core/dtypes/cast.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def conv(r, dtype):
606606
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
607607

608608

609-
def astype_nansafe(arr, dtype, copy=True, skipna=False):
609+
def astype_nansafe(arr, dtype, copy=True, skipna=False, casting='unsafe'):
610610
"""
611611
Cast the elements of an array to a given dtype in a nan-safe manner.
612612
@@ -617,8 +617,10 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
617617
copy : bool, default True
618618
If False, a view will be attempted but may fail, if
619619
e.g. the item sizes don't align.
620-
skipna: bool, default False
620+
skipna : bool, default False
621621
Whether or not we should skip NaN when casting as a string-type.
622+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default 'unsafe'
623+
Passed to ``arr.astype`` to control what kind of data casting may occur.
622624
623625
Raises
624626
------
@@ -704,7 +706,7 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
704706

705707
if copy or is_object_dtype(arr) or is_object_dtype(dtype):
706708
# Explicit copy, or required since NumPy can't view from / to object.
707-
return arr.astype(dtype, copy=True)
709+
return arr.astype(dtype, copy=True, casting=casting)
708710

709711
return arr.view(dtype)
710712

pandas/core/frame.py

+45
Original file line numberDiff line numberDiff line change
@@ -2634,6 +2634,51 @@ def transpose(self, *args, **kwargs):
26342634

26352635
T = property(transpose)
26362636

2637+
# ----------------------------------------------------------------------
2638+
# Array Interface
2639+
2640+
# This is also set in IndexOpsMixin
2641+
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
2642+
__array_priority__ = 1000
2643+
2644+
def __array__(self, dtype=None):
2645+
return com.values_from_object(self)
2646+
2647+
def __array_wrap__(self, result: np.ndarray, context=None) -> 'DataFrame':
2648+
"""
2649+
We are called post ufunc; reconstruct the original object and dtypes.
2650+
2651+
Parameters
2652+
----------
2653+
result : np.ndarray
2654+
context
2655+
2656+
Returns
2657+
-------
2658+
DataFrame
2659+
"""
2660+
2661+
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
2662+
result = self._constructor(result, **d)
2663+
2664+
# we try to cast extension array types back to the original
2665+
# TODO: this fails with duplicates, ugh
2666+
if self._data.any_extension_types:
2667+
result = result.astype(self.dtypes,
2668+
copy=False,
2669+
errors='ignore',
2670+
casting='same_kind')
2671+
2672+
return result.__finalize__(self)
2673+
2674+
# ideally we would define this to avoid the getattr checks, but
2675+
# is slower
2676+
# @property
2677+
# def __array_interface__(self):
2678+
# """ provide numpy array interface method """
2679+
# values = self.values
2680+
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
2681+
26372682
# ----------------------------------------------------------------------
26382683
# Picklability
26392684

pandas/core/generic.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -1910,25 +1910,6 @@ def empty(self):
19101910
# ----------------------------------------------------------------------
19111911
# Array Interface
19121912

1913-
# This is also set in IndexOpsMixin
1914-
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
1915-
__array_priority__ = 1000
1916-
1917-
def __array__(self, dtype=None):
1918-
return com.values_from_object(self)
1919-
1920-
def __array_wrap__(self, result, context=None):
1921-
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
1922-
return self._constructor(result, **d).__finalize__(self)
1923-
1924-
# ideally we would define this to avoid the getattr checks, but
1925-
# is slower
1926-
# @property
1927-
# def __array_interface__(self):
1928-
# """ provide numpy array interface method """
1929-
# values = self.values
1930-
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
1931-
19321913
def to_dense(self):
19331914
"""
19341915
Return dense representation of NDFrame (as opposed to sparse).
@@ -5679,6 +5660,11 @@ def astype(self, dtype, copy=True, errors='raise', **kwargs):
56795660
**kwargs)
56805661
return self._constructor(new_data).__finalize__(self)
56815662

5663+
if not results:
5664+
if copy:
5665+
self = self.copy()
5666+
return self
5667+
56825668
# GH 19920: retain column metadata after concat
56835669
result = pd.concat(results, axis=1, copy=False)
56845670
result.columns = self.columns

pandas/core/groupby/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,
110110

111111
obj = self.obj[data.items[locs]]
112112
s = groupby(obj, self.grouper)
113-
result = s.aggregate(lambda x: alt(x, axis=self.axis))
113+
try:
114+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
115+
except Exception:
116+
# we may have an exception in trying to aggregate
117+
# continue and exclude the block
118+
pass
114119

115120
finally:
116121

122+
dtype = block.values.dtype
123+
117124
# see if we can cast the block back to the original dtype
118-
result = block._try_coerce_and_cast_result(result)
125+
result = block._try_coerce_and_cast_result(result, dtype=dtype)
119126
newb = block.make_block(result)
120127

121128
new_items.append(locs)

pandas/core/groupby/groupby.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,8 @@ def _try_cast(self, result, obj, numeric_only=False):
785785
elif is_extension_array_dtype(dtype):
786786
# The function can return something of any type, so check
787787
# if the type is compatible with the calling EA.
788+
789+
# return the same type (Series) as our caller
788790
try:
789791
result = obj._values._from_sequence(result, dtype=dtype)
790792
except Exception:
@@ -1156,7 +1158,8 @@ def mean(self, *args, **kwargs):
11561158
"""
11571159
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
11581160
try:
1159-
return self._cython_agg_general('mean', **kwargs)
1161+
return self._cython_agg_general(
1162+
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
11601163
except GroupByError:
11611164
raise
11621165
except Exception: # pragma: no cover
@@ -1178,7 +1181,11 @@ def median(self, **kwargs):
11781181
Median of values within each group.
11791182
"""
11801183
try:
1181-
return self._cython_agg_general('median', **kwargs)
1184+
return self._cython_agg_general(
1185+
'median',
1186+
alt=lambda x,
1187+
axis: Series(x).median(**kwargs),
1188+
**kwargs)
11821189
except GroupByError:
11831190
raise
11841191
except Exception: # pragma: no cover
@@ -1234,7 +1241,10 @@ def var(self, ddof=1, *args, **kwargs):
12341241
nv.validate_groupby_func('var', args, kwargs)
12351242
if ddof == 1:
12361243
try:
1237-
return self._cython_agg_general('var', **kwargs)
1244+
return self._cython_agg_general(
1245+
'var',
1246+
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
1247+
**kwargs)
12381248
except Exception:
12391249
f = lambda x: x.var(ddof=ddof, **kwargs)
12401250
with _group_selection_context(self):
@@ -1262,7 +1272,6 @@ def sem(self, ddof=1):
12621272
Series or DataFrame
12631273
Standard error of the mean of values within each group.
12641274
"""
1265-
12661275
return self.std(ddof=ddof) / np.sqrt(self.count())
12671276

12681277
@Substitution(name='groupby')
@@ -1319,6 +1328,16 @@ def f(self, **kwargs):
13191328
except Exception:
13201329
result = self.aggregate(
13211330
lambda x: npfunc(x, axis=self.axis))
1331+
1332+
# coerce the columns if we can
1333+
if isinstance(result, DataFrame):
1334+
for col in result.columns:
1335+
result[col] = self._try_cast(
1336+
result[col], self.obj[col])
1337+
else:
1338+
result = self._try_cast(
1339+
result, self.obj)
1340+
13221341
if _convert:
13231342
result = result._convert(datetime=True)
13241343
return result

pandas/core/groupby/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
2121
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
22-
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
22+
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
2323
is_timedelta64_dtype, needs_i8_conversion)
2424
from pandas.core.dtypes.missing import _maybe_fill, isna
2525

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,
451451

452452
# categoricals are only 1d, so we
453453
# are not setup for dim transforming
454-
if is_categorical_dtype(values):
454+
if is_categorical_dtype(values) or is_sparse(values):
455455
raise NotImplementedError(
456-
"categoricals are not support in cython ops ATM")
456+
"{} are not support in cython ops".format(values.dtype))
457457
elif is_datetime64_any_dtype(values):
458458
if how in ['add', 'prod', 'cumsum', 'cumprod']:
459459
raise NotImplementedError(

pandas/core/internals/blocks.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
604604
values = self.get_values(dtype=dtype)
605605

606606
# _astype_nansafe works fine with 1-d only
607-
values = astype_nansafe(values.ravel(), dtype, copy=True)
607+
values = astype_nansafe(
608+
values.ravel(), dtype, copy=True, **kwargs)
608609

609610
# TODO(extension)
610611
# should we make this attribute?
@@ -1771,6 +1772,19 @@ def _slice(self, slicer):
17711772

17721773
return self.values[slicer]
17731774

1775+
def _try_cast_result(self, result, dtype=None):
1776+
"""
1777+
if we have an operation that operates on for example floats
1778+
we want to try to cast back to our EA here if possible
1779+
"""
1780+
try:
1781+
result = self._holder._from_sequence(
1782+
np.asarray(result).ravel(), dtype=dtype)
1783+
except Exception:
1784+
pass
1785+
1786+
return result
1787+
17741788
def formatting_values(self):
17751789
# Deprecating the ability to override _formatting_values.
17761790
# Do the warning here, it's only user in pandas, since we

pandas/core/nanops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ def _f(*args, **kwargs):
8989

9090
class bottleneck_switch:
9191

92-
def __init__(self, **kwargs):
92+
def __init__(self, name=None, **kwargs):
93+
self.name = name
9394
self.kwargs = kwargs
9495

9596
def __call__(self, alt):
96-
bn_name = alt.__name__
97+
bn_name = self.name or alt.__name__
9798

9899
try:
99100
bn_func = getattr(bn, bn_name)
@@ -821,7 +822,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
821822

822823

823824
def _nanminmax(meth, fill_value_typ):
824-
@bottleneck_switch()
825+
826+
@bottleneck_switch(name='nan' + meth)
825827
def reduction(values, axis=None, skipna=True, mask=None):
826828

827829
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -841,7 +843,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
841843
result = _wrap_results(result, dtype, fill_value)
842844
return _maybe_null_out(result, axis, mask, values.shape)
843845

844-
reduction.__name__ = 'nan' + meth
845846
return reduction
846847

847848

pandas/core/series.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -745,12 +745,31 @@ def __array__(self, dtype=None):
745745
dtype = 'M8[ns]'
746746
return np.asarray(self.array, dtype)
747747

748-
def __array_wrap__(self, result, context=None):
748+
def __array_wrap__(self, result: np.ndarray, context=None) -> 'Series':
749749
"""
750-
Gets called after a ufunc.
750+
We are called post ufunc; reconstruct the original object and dtypes.
751+
752+
Parameters
753+
----------
754+
result : np.ndarray
755+
context
756+
757+
Returns
758+
-------
759+
Series
751760
"""
752-
return self._constructor(result, index=self.index,
753-
copy=False).__finalize__(self)
761+
762+
result = self._constructor(result, index=self.index,
763+
copy=False)
764+
765+
# we try to cast extension array types back to the original
766+
if is_extension_array_dtype(self):
767+
result = result.astype(self.dtype,
768+
copy=False,
769+
errors='ignore',
770+
casting='same_kind')
771+
772+
return result.__finalize__(self)
754773

755774
def __array_prepare__(self, result, context=None):
756775
"""

0 commit comments

Comments
 (0)