Skip to content

Commit b6a3eae

Browse files
committed
BUG: preserve categorical & sparse types when grouping / pivot
preserve dtypes when applying a ufunc to a sparse dtype closes pandas-dev#18502 closes pandas-dev#23743
1 parent a91da0c commit b6a3eae

File tree

17 files changed

+283
-81
lines changed

17 files changed

+283
-81
lines changed

doc/source/whatsnew/v0.25.0.rst

+60-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,64 @@ returned if all the columns were dummy encoded, and a :class:`DataFrame` otherwi
154154
Providing any ``SparseSeries`` or ``SparseDataFrame`` to :func:`concat` will
155155
cause a ``SparseSeries`` or ``SparseDataFrame`` to be returned, as before.
156156

157-
.. _whatsnew_0250.api_breaking.incompatible_index_unions
157+
.. _whatsnew_0250.api_breaking.ufuncs:
158+
159+
ufuncs on Extension Dtype
160+
^^^^^^^^^^^^^^^^^^^^^^^^^
161+
162+
Operations with ``numpy`` ufuncs on Extension Arrays, including Sparse Dtypes, will now coerce the
163+
resulting dtypes to the same as the input dtype; previously this would coerce to a dense dtype. (:issue:`23743`)
164+
165+
.. ipython:: python
166+
167+
df = pd.DataFrame({'A': pd.Series([1, np.nan, 3], dtype=pd.SparseDtype('float64', np.nan))})
168+
df
169+
df.dtypes
170+
171+
*Previous Behavior*:
172+
173+
.. code-block:: python
174+
175+
In [3]: np.sqrt(df).dtypes
176+
Out[3]:
177+
A float64
178+
dtype: object
179+
180+
*New Behavior*:
181+
182+
.. ipython:: python
183+
184+
np.sqrt(df).dtypes
185+
186+
.. _whatsnew_0250.api_breaking.groupby_categorical:
187+
188+
Categorical dtypes are preserved during groupby
189+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
190+
191+
Previously, columns that were categorical, but not the groupby key(s) would be converted to ``object`` dtype during groupby operations. Pandas now will preserve these dtypes. (:issue:`18502`)
192+
193+
.. ipython:: python
194+
195+
df = pd.DataFrame({'payload': [-1,-2,-1,-2],
196+
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
197+
df
198+
df.dtypes
199+
200+
*Previous Behavior*:
201+
202+
.. code-block:: python
203+
204+
In [5]: df.groupby('payload').first().col.dtype
205+
Out[5]: dtype('O')
206+
207+
*New Behavior*:
208+
209+
.. ipython:: python
210+
211+
df.groupby('payload').first().col.dtype
212+
213+
214+
.. _whatsnew_0250.api_breaking.incompatible_index_unions:
158215

159216
Incompatible Index Type Unions
160217
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -168,6 +225,8 @@ considered commutative, such that ``A.union(B) == B.union(A)`` (:issue:`23525`).
168225

169226
*Previous Behavior*:
170227

228+
.. code-block:: python
229+
171230
In [1]: pd.period_range('19910905', periods=2).union(pd.Int64Index([1, 2, 3]))
172231
...
173232
ValueError: can only call with other PeriodIndex-ed objects

pandas/core/dtypes/cast.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def conv(r, dtype):
606606
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
607607

608608

609-
def astype_nansafe(arr, dtype, copy=True, skipna=False):
609+
def astype_nansafe(arr, dtype, copy=True, skipna=False, casting='unsafe'):
610610
"""
611611
Cast the elements of an array to a given dtype a nan-safe manner.
612612
@@ -617,8 +617,10 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
617617
copy : bool, default True
618618
If False, a view will be attempted but may fail, if
619619
e.g. the item sizes don't align.
620-
skipna: bool, default False
620+
skipna : bool, default False
621621
Whether or not we should skip NaN when casting as a string-type.
622+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}
623+
optional, default 'unsafe'
622624
623625
Raises
624626
------
@@ -704,7 +706,7 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
704706

705707
if copy or is_object_dtype(arr) or is_object_dtype(dtype):
706708
# Explicit copy, or required since NumPy can't view from / to object.
707-
return arr.astype(dtype, copy=True)
709+
return arr.astype(dtype, copy=True, casting=casting)
708710

709711
return arr.view(dtype)
710712

pandas/core/frame.py

+45
Original file line numberDiff line numberDiff line change
@@ -2634,6 +2634,51 @@ def transpose(self, *args, **kwargs):
26342634

26352635
T = property(transpose)
26362636

2637+
# ----------------------------------------------------------------------
2638+
# Array Interface
2639+
2640+
# This is also set in IndexOpsMixin
2641+
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
2642+
__array_priority__ = 1000
2643+
2644+
def __array__(self, dtype=None):
2645+
return com.values_from_object(self)
2646+
2647+
def __array_wrap__(self, result: np.ndarray, context=None) -> 'DataFrame':
2648+
"""
2649+
We are called post ufunc; reconstruct the original object and dtypes.
2650+
2651+
Parameters
2652+
----------
2653+
result : np.ndarray
2654+
context
2655+
2656+
Returns
2657+
-------
2658+
DataFrame
2659+
"""
2660+
2661+
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
2662+
result = self._constructor(result, **d)
2663+
2664+
# we try to cast extension array types back to the original
2665+
# TODO: this fails with duplicates, ugh
2666+
if self._data.any_extension_types:
2667+
result = result.astype(self.dtypes,
2668+
copy=False,
2669+
errors='ignore',
2670+
casting='same_kind')
2671+
2672+
return result.__finalize__(self)
2673+
2674+
# ideally we would define this to avoid the getattr checks, but
2675+
# is slower
2676+
# @property
2677+
# def __array_interface__(self):
2678+
# """ provide numpy array interface method """
2679+
# values = self.values
2680+
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
2681+
26372682
# ----------------------------------------------------------------------
26382683
# Picklability
26392684

pandas/core/generic.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -1941,25 +1941,6 @@ def empty(self):
19411941
# ----------------------------------------------------------------------
19421942
# Array Interface
19431943

1944-
# This is also set in IndexOpsMixin
1945-
# GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented
1946-
__array_priority__ = 1000
1947-
1948-
def __array__(self, dtype=None):
1949-
return com.values_from_object(self)
1950-
1951-
def __array_wrap__(self, result, context=None):
1952-
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
1953-
return self._constructor(result, **d).__finalize__(self)
1954-
1955-
# ideally we would define this to avoid the getattr checks, but
1956-
# is slower
1957-
# @property
1958-
# def __array_interface__(self):
1959-
# """ provide numpy array interface method """
1960-
# values = self.values
1961-
# return dict(typestr=values.dtype.str,shape=values.shape,data=values)
1962-
19631944
def to_dense(self):
19641945
"""
19651946
Return dense representation of NDFrame (as opposed to sparse).
@@ -5755,6 +5736,11 @@ def astype(self, dtype, copy=True, errors='raise', **kwargs):
57555736
**kwargs)
57565737
return self._constructor(new_data).__finalize__(self)
57575738

5739+
if not results:
5740+
if copy:
5741+
self = self.copy()
5742+
return self
5743+
57585744
# GH 19920: retain column metadata after concat
57595745
result = pd.concat(results, axis=1, copy=False)
57605746
result.columns = self.columns

pandas/core/groupby/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,
104104

105105
obj = self.obj[data.items[locs]]
106106
s = groupby(obj, self.grouper)
107-
result = s.aggregate(lambda x: alt(x, axis=self.axis))
107+
try:
108+
result = s.aggregate(lambda x: alt(x, axis=self.axis))
109+
except Exception:
110+
# we may have an exception in trying to aggregate
111+
# continue and exclude the block
112+
pass
108113

109114
finally:
110115

116+
dtype = block.values.dtype
117+
111118
# see if we can cast the block back to the original dtype
112-
result = block._try_coerce_and_cast_result(result)
119+
result = block._try_coerce_and_cast_result(result, dtype=dtype)
113120
newb = block.make_block(result)
114121

115122
new_items.append(locs)

pandas/core/groupby/groupby.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ def _try_cast(self, result, obj, numeric_only=False):
784784
elif is_extension_array_dtype(dtype):
785785
# The function can return something of any type, so check
786786
# if the type is compatible with the calling EA.
787+
788+
# return the same type (Series) as our caller
787789
try:
788790
result = obj._values._from_sequence(result, dtype=dtype)
789791
except Exception:
@@ -1155,7 +1157,8 @@ def mean(self, *args, **kwargs):
11551157
"""
11561158
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
11571159
try:
1158-
return self._cython_agg_general('mean', **kwargs)
1160+
return self._cython_agg_general(
1161+
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
11591162
except GroupByError:
11601163
raise
11611164
except Exception: # pragma: no cover
@@ -1177,7 +1180,11 @@ def median(self, **kwargs):
11771180
Median of values within each group.
11781181
"""
11791182
try:
1180-
return self._cython_agg_general('median', **kwargs)
1183+
return self._cython_agg_general(
1184+
'median',
1185+
alt=lambda x,
1186+
axis: Series(x).median(**kwargs),
1187+
**kwargs)
11811188
except GroupByError:
11821189
raise
11831190
except Exception: # pragma: no cover
@@ -1233,7 +1240,10 @@ def var(self, ddof=1, *args, **kwargs):
12331240
nv.validate_groupby_func('var', args, kwargs)
12341241
if ddof == 1:
12351242
try:
1236-
return self._cython_agg_general('var', **kwargs)
1243+
return self._cython_agg_general(
1244+
'var',
1245+
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
1246+
**kwargs)
12371247
except Exception:
12381248
f = lambda x: x.var(ddof=ddof, **kwargs)
12391249
with _group_selection_context(self):
@@ -1261,7 +1271,6 @@ def sem(self, ddof=1):
12611271
Series or DataFrame
12621272
Standard error of the mean of values within each group.
12631273
"""
1264-
12651274
return self.std(ddof=ddof) / np.sqrt(self.count())
12661275

12671276
@Substitution(name='groupby')
@@ -1318,6 +1327,16 @@ def f(self, **kwargs):
13181327
except Exception:
13191328
result = self.aggregate(
13201329
lambda x: npfunc(x, axis=self.axis))
1330+
1331+
# coerce the columns if we can
1332+
if isinstance(result, DataFrame):
1333+
for col in result.columns:
1334+
result[col] = self._try_cast(
1335+
result[col], self.obj[col])
1336+
else:
1337+
result = self._try_cast(
1338+
result, self.obj)
1339+
13211340
if _convert:
13221341
result = result._convert(datetime=True)
13231342
return result

pandas/core/groupby/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
2121
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
22-
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
22+
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
2323
is_timedelta64_dtype, needs_i8_conversion)
2424
from pandas.core.dtypes.missing import _maybe_fill, isna
2525

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,
451451

452452
# categoricals are only 1d, so we
453453
# are not setup for dim transforming
454-
if is_categorical_dtype(values):
454+
if is_categorical_dtype(values) or is_sparse(values):
455455
raise NotImplementedError(
456-
"categoricals are not support in cython ops ATM")
456+
"{} are not support in cython ops".format(values.dtype))
457457
elif is_datetime64_any_dtype(values):
458458
if how in ['add', 'prod', 'cumsum', 'cumprod']:
459459
raise NotImplementedError(

pandas/core/internals/blocks.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
604604
values = self.get_values(dtype=dtype)
605605

606606
# _astype_nansafe works fine with 1-d only
607-
values = astype_nansafe(values.ravel(), dtype, copy=True)
607+
values = astype_nansafe(
608+
values.ravel(), dtype, copy=True, **kwargs)
608609

609610
# TODO(extension)
610611
# should we make this attribute?
@@ -1771,6 +1772,19 @@ def _slice(self, slicer):
17711772

17721773
return self.values[slicer]
17731774

1775+
def _try_cast_result(self, result, dtype=None):
1776+
"""
1777+
if we have an operation that operates on for example floats
1778+
we want to try to cast back to our EA here if possible
1779+
"""
1780+
try:
1781+
result = self._holder._from_sequence(
1782+
np.asarray(result).ravel(), dtype=dtype)
1783+
except Exception:
1784+
pass
1785+
1786+
return result
1787+
17741788
def formatting_values(self):
17751789
# Deprecating the ability to override _formatting_values.
17761790
# Do the warning here, it's only user in pandas, since we

pandas/core/nanops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ def _f(*args, **kwargs):
8989

9090
class bottleneck_switch:
9191

92-
def __init__(self, **kwargs):
92+
def __init__(self, name=None, **kwargs):
93+
self.name = name
9394
self.kwargs = kwargs
9495

9596
def __call__(self, alt):
96-
bn_name = alt.__name__
97+
bn_name = self.name or alt.__name__
9798

9899
try:
99100
bn_func = getattr(bn, bn_name)
@@ -821,7 +822,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
821822

822823

823824
def _nanminmax(meth, fill_value_typ):
824-
@bottleneck_switch()
825+
826+
@bottleneck_switch(name='nan' + meth)
825827
def reduction(values, axis=None, skipna=True, mask=None):
826828

827829
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -841,7 +843,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
841843
result = _wrap_results(result, dtype, fill_value)
842844
return _maybe_null_out(result, axis, mask, values.shape)
843845

844-
reduction.__name__ = 'nan' + meth
845846
return reduction
846847

847848

0 commit comments

Comments
 (0)