Skip to content

Commit 5bec6b3

Browse files
committed
BUG: preserve categorical & sparse types when grouping / pivot
Preserve dtypes when applying a ufunc to a sparse dtype. Closes pandas-dev#18502; closes pandas-dev#23743.
1 parent b9b081d commit 5bec6b3

File tree

17 files changed

+282
-80
lines changed

17 files changed

+282
-80
lines changed

doc/source/whatsnew/v0.25.0.rst

+59
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,65 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
316316
s
317317
s.str.startswith(b'a')
318318
319+
320+
321+
.. _whatsnew_0250.api_breaking.ufuncs:
322+
323+
ufuncs on Extension Dtype
324+
^^^^^^^^^^^^^^^^^^^^^^^^^
325+
326+
Operations with ``numpy`` ufuncs on Extension Arrays, including Sparse Dtypes, will now coerce the
327+
resulting dtypes to be the same as the input dtype; previously this would coerce to a dense dtype. (:issue:`23743`)
328+
329+
.. ipython:: python
330+
331+
df = pd.DataFrame({'A': pd.Series([1, np.nan, 3], dtype=pd.SparseDtype('float64', np.nan))})
332+
df
333+
df.dtypes
334+
335+
*Previous Behavior*:
336+
337+
.. code-block:: python
338+
339+
In [3]: np.sqrt(df).dtypes
340+
Out[3]:
341+
A float64
342+
dtype: object
343+
344+
*New Behavior*:
345+
346+
.. ipython:: python
347+
348+
np.sqrt(df).dtypes
349+
350+
.. _whatsnew_0250.api_breaking.groupby_categorical:
351+
352+
Categorical dtypes are preserved during groupby
353+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
354+
355+
Previously, columns that were categorical, but not the groupby key(s), would be converted to ``object`` dtype during groupby operations. Pandas will now preserve these dtypes. (:issue:`18502`)
356+
357+
.. ipython:: python
358+
359+
df = pd.DataFrame({'payload': [-1,-2,-1,-2],
360+
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
361+
df
362+
df.dtypes
363+
364+
*Previous Behavior*:
365+
366+
.. code-block:: python
367+
368+
In [5]: df.groupby('payload').first().col.dtype
369+
Out[5]: dtype('O')
370+
371+
*New Behavior*:
372+
373+
.. ipython:: python
374+
375+
df.groupby('payload').first().col.dtype
376+
377+
319378
.. _whatsnew_0250.api_breaking.incompatible_index_unions:
320379

321380
Incompatible Index Type Unions

pandas/core/dtypes/cast.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def conv(r, dtype):
605605
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
606606

607607

608-
def astype_nansafe(arr, dtype, copy=True, skipna=False):
608+
def astype_nansafe(arr, dtype, copy=True, skipna=False, casting='unsafe'):
609609
"""
610610
Cast the elements of an array to a given dtype in a nan-safe manner.
611611
@@ -616,8 +616,10 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
616616
copy : bool, default True
617617
If False, a view will be attempted but may fail, if
618618
e.g. the item sizes don't align.
619-
skipna: bool, default False
619+
skipna : bool, default False
620620
Whether or not we should skip NaN when casting as a string-type.
621+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}
622+
optional, default 'unsafe'
621623
622624
Raises
623625
------
@@ -703,7 +705,7 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
703705

704706
if copy or is_object_dtype(arr) or is_object_dtype(dtype):
705707
# Explicit copy, or required since NumPy can't view from / to object.
706-
return arr.astype(dtype, copy=True)
708+
return arr.astype(dtype, copy=True, casting=casting)
707709

708710
return arr.view(dtype)
709711

pandas/core/frame.py

+45
Original file line numberDiff line numberDiff line change
@@ -2641,6 +2641,51 @@ def transpose(self, *args, **kwargs):
26412641

26422642
T = property(transpose)
26432643

2644+
# ----------------------------------------------------------------------
2645+
# Array Interface
2646+
2647+
# This is also set in IndexOpsMixin
2648+
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
2649+
__array_priority__ = 1000
2650+
2651+
def __array__(self, dtype=None):
2652+
return com.values_from_object(self)
2653+
2654+
def __array_wrap__(self, result: np.ndarray, context=None) -> 'DataFrame':
2655+
"""
2656+
We are called post ufunc; reconstruct the original object and dtypes.
2657+
2658+
Parameters
2659+
----------
2660+
result : np.ndarray
2661+
context
2662+
2663+
Returns
2664+
-------
2665+
DataFrame
2666+
"""
2667+
2668+
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
2669+
result = self._constructor(result, **d)
2670+
2671+
# we try to cast extension array types back to the original
2672+
# TODO: this fails with duplicates, ugh
2673+
if self._data.any_extension_types:
2674+
result = result.astype(self.dtypes,
2675+
copy=False,
2676+
errors='ignore',
2677+
casting='same_kind')
2678+
2679+
return result.__finalize__(self)
2680+
2681+
# ideally we would define this to avoid the getattr checks, but
2682+
# is slower
2683+
# @property
2684+
# def __array_interface__(self):
2685+
# """ provide numpy array interface method """
2686+
# values = self.values
2687+
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
2688+
26442689
# ----------------------------------------------------------------------
26452690
# Picklability
26462691

pandas/core/generic.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -1919,25 +1919,6 @@ def empty(self):
19191919
# ----------------------------------------------------------------------
19201920
# Array Interface
19211921

1922-
# This is also set in IndexOpsMixin
1923-
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
1924-
__array_priority__ = 1000
1925-
1926-
def __array__(self, dtype=None):
1927-
return com.values_from_object(self)
1928-
1929-
def __array_wrap__(self, result, context=None):
1930-
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
1931-
return self._constructor(result, **d).__finalize__(self)
1932-
1933-
# ideally we would define this to avoid the getattr checks, but
1934-
# is slower
1935-
# @property
1936-
# def __array_interface__(self):
1937-
# """ provide numpy array interface method """
1938-
# values = self.values
1939-
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
1940-
19411922
def to_dense(self):
19421923
"""
19431924
Return dense representation of NDFrame (as opposed to sparse).
@@ -5693,6 +5674,11 @@ def astype(self, dtype, copy=True, errors='raise', **kwargs):
56935674
**kwargs)
56945675
return self._constructor(new_data).__finalize__(self)
56955676

5677+
if not results:
5678+
if copy:
5679+
self = self.copy()
5680+
return self
5681+
56965682
# GH 19920: retain column metadata after concat
56975683
result = pd.concat(results, axis=1, copy=False)
56985684
result.columns = self.columns

pandas/core/groupby/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,
156156

157157
obj = self.obj[data.items[locs]]
158158
s = groupby(obj, self.grouper)
159-
result = s.aggregate(lambda x: alt(x, axis=self.axis))
159+
try:
160+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
161+
except Exception:
162+
# we may have an exception in trying to aggregate
163+
# continue and exclude the block
164+
pass
160165

161166
finally:
162167

168+
dtype = block.values.dtype
169+
163170
# see if we can cast the block back to the original dtype
164-
result = block._try_coerce_and_cast_result(result)
171+
result = block._try_coerce_and_cast_result(result, dtype=dtype)
165172
newb = block.make_block(result)
166173

167174
new_items.append(locs)

pandas/core/groupby/groupby.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
786786
elif is_extension_array_dtype(dtype):
787787
# The function can return something of any type, so check
788788
# if the type is compatible with the calling EA.
789+
790+
# return the same type (Series) as our caller
789791
try:
790792
result = obj._values._from_sequence(result, dtype=dtype)
791793
except Exception:
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
11571159
"""
11581160
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
11591161
try:
1160-
return self._cython_agg_general('mean', **kwargs)
1162+
return self._cython_agg_general(
1163+
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
11611164
except GroupByError:
11621165
raise
11631166
except Exception: # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
11791182
Median of values within each group.
11801183
"""
11811184
try:
1182-
return self._cython_agg_general('median', **kwargs)
1185+
return self._cython_agg_general(
1186+
'median',
1187+
alt=lambda x,
1188+
axis: Series(x).median(**kwargs),
1189+
**kwargs)
11831190
except GroupByError:
11841191
raise
11851192
except Exception: # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
12351242
nv.validate_groupby_func('var', args, kwargs)
12361243
if ddof == 1:
12371244
try:
1238-
return self._cython_agg_general('var', **kwargs)
1245+
return self._cython_agg_general(
1246+
'var',
1247+
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
1248+
**kwargs)
12391249
except Exception:
12401250
f = lambda x: x.var(ddof=ddof, **kwargs)
12411251
with _group_selection_context(self):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
12631273
Series or DataFrame
12641274
Standard error of the mean of values within each group.
12651275
"""
1266-
12671276
return self.std(ddof=ddof) / np.sqrt(self.count())
12681277

12691278
@Substitution(name='groupby')
@@ -1320,6 +1329,16 @@ def f(self, **kwargs):
13201329
except Exception:
13211330
result = self.aggregate(
13221331
lambda x: npfunc(x, axis=self.axis))
1332+
1333+
# coerce the columns if we can
1334+
if isinstance(result, DataFrame):
1335+
for col in result.columns:
1336+
result[col] = self._try_cast(
1337+
result[col], self.obj[col])
1338+
else:
1339+
result = self._try_cast(
1340+
result, self.obj)
1341+
13231342
if _convert:
13241343
result = result._convert(datetime=True)
13251344
return result

pandas/core/groupby/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
2121
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
22-
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
22+
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
2323
is_timedelta64_dtype, needs_i8_conversion)
2424
from pandas.core.dtypes.missing import _maybe_fill, isna
2525

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,
451451

452452
# categoricals are only 1d, so we
453453
# are not setup for dim transforming
454-
if is_categorical_dtype(values):
454+
if is_categorical_dtype(values) or is_sparse(values):
455455
raise NotImplementedError(
456-
"categoricals are not support in cython ops ATM")
456+
"{} are not support in cython ops".format(values.dtype))
457457
elif is_datetime64_any_dtype(values):
458458
if how in ['add', 'prod', 'cumsum', 'cumprod']:
459459
raise NotImplementedError(

pandas/core/internals/blocks.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
600600
values = self.get_values(dtype=dtype)
601601

602602
# _astype_nansafe works fine with 1-d only
603-
values = astype_nansafe(values.ravel(), dtype, copy=True)
603+
values = astype_nansafe(
604+
values.ravel(), dtype, copy=True, **kwargs)
604605

605606
# TODO(extension)
606607
# should we make this attribute?
@@ -1767,6 +1768,19 @@ def _slice(self, slicer):
17671768

17681769
return self.values[slicer]
17691770

1771+
def _try_cast_result(self, result, dtype=None):
1772+
"""
1773+
If we have an operation that operates on, for example, floats,
1774+
we want to try to cast back to our extension array (EA) here if possible.
1775+
"""
1776+
try:
1777+
result = self._holder._from_sequence(
1778+
np.asarray(result).ravel(), dtype=dtype)
1779+
except Exception:
1780+
pass
1781+
1782+
return result
1783+
17701784
def formatting_values(self):
17711785
# Deprecating the ability to override _formatting_values.
17721786
# Do the warning here, it's only user in pandas, since we

pandas/core/nanops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ def _f(*args, **kwargs):
7272

7373
class bottleneck_switch:
7474

75-
def __init__(self, **kwargs):
75+
def __init__(self, name=None, **kwargs):
76+
self.name = name
7677
self.kwargs = kwargs
7778

7879
def __call__(self, alt):
79-
bn_name = alt.__name__
80+
bn_name = self.name or alt.__name__
8081

8182
try:
8283
bn_func = getattr(bn, bn_name)
@@ -804,7 +805,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
804805

805806

806807
def _nanminmax(meth, fill_value_typ):
807-
@bottleneck_switch()
808+
809+
@bottleneck_switch(name='nan' + meth)
808810
def reduction(values, axis=None, skipna=True, mask=None):
809811

810812
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -824,7 +826,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
824826
result = _wrap_results(result, dtype, fill_value)
825827
return _maybe_null_out(result, axis, mask, values.shape)
826828

827-
reduction.__name__ = 'nan' + meth
828829
return reduction
829830

830831

pandas/core/series.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -762,12 +762,31 @@ def __array__(self, dtype=None):
762762
dtype = 'M8[ns]'
763763
return np.asarray(self.array, dtype)
764764

765-
def __array_wrap__(self, result, context=None):
765+
def __array_wrap__(self, result: np.ndarray, context=None) -> 'Series':
766766
"""
767-
Gets called after a ufunc.
767+
We are called post ufunc; reconstruct the original object and dtypes.
768+
769+
Parameters
770+
----------
771+
result : np.ndarray
772+
context
773+
774+
Returns
775+
-------
776+
Series
768777
"""
769-
return self._constructor(result, index=self.index,
770-
copy=False).__finalize__(self)
778+
779+
result = self._constructor(result, index=self.index,
780+
copy=False)
781+
782+
# we try to cast extension array types back to the original
783+
if is_extension_array_dtype(self):
784+
result = result.astype(self.dtype,
785+
copy=False,
786+
errors='ignore',
787+
casting='same_kind')
788+
789+
return result.__finalize__(self)
771790

772791
def __array_prepare__(self, result, context=None):
773792
"""

0 commit comments

Comments
 (0)