Skip to content

Commit aba1467

Browse files
committed
Aligned group_fillna and group_shift signatures
1 parent fae5707 commit aba1467

File tree

2 files changed

+75
-37
lines changed

2 files changed

+75
-37
lines changed

pandas/_libs/groupby_helper.pxi.in

+11-9
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ def group_cumsum(numeric[:, :] out,
906906

907907
@cython.boundscheck(False)
908908
@cython.wraparound(False)
909-
def group_shift_indexer(int64_t[:] out, int64_t[:] labels,
909+
def group_shift_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
910910
int ngroups, int periods):
911911
cdef:
912912
Py_ssize_t N, i, j, ii
@@ -957,21 +957,19 @@ def group_shift_indexer(int64_t[:] out, int64_t[:] labels,
957957

958958
@cython.wraparound(False)
959959
@cython.boundscheck(False)
960-
def group_fillna_indexer(ndarray[int64_t] out,
961-
ndarray[uint8_t] mask,
962-
ndarray[int64_t] labels,
963-
object method,
960+
def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
961+
ndarray[uint8_t] mask, object direction,
964962
int64_t limit):
965-
"""Fills values forwards or backwards within a group
963+
"""Indexes how to fill values forwards or backwards within a group
966964

967965
Parameters
968966
----------
969967
out : array of int64_t values which this method will write its results to
970968
Missing values will be written to with a value of -1
971-
mask : array of int64_t values where a 1 indicates a missing value
972969
labels : array containing unique label for each group, with its ordering
973970
matching up to the corresponding record in `values`
974-
method : {'ffill', 'bfill'}
971+
mask : array of int64_t values where a 1 indicates a missing value
972+
direction : {'ffill', 'bfill'}
975973
Direction for fill to be applied (forwards or backwards, respectively)
976974
limit : Consecutive values to fill before stopping, or -1 for no limit
977975

@@ -987,8 +985,11 @@ def group_fillna_indexer(ndarray[int64_t] out,
987985

988986
N = len(out)
989987

988+
# Make sure all arrays are the same size
989+
assert N == len(labels) == len(mask)
990+
990991
sorted_labels = np.argsort(labels).astype(np.int64, copy=False)
991-
if method == 'bfill':
992+
if direction == 'bfill':
992993
sorted_labels = sorted_labels[::-1]
993994

994995
with nogil:
@@ -1004,6 +1005,7 @@ def group_fillna_indexer(ndarray[int64_t] out,
10041005
curr_fill_idx = idx
10051006

10061007
out[idx] = curr_fill_idx
1008+
10071009
# If we move to the next group, reset
10081010
# the fill_idx and counter
10091011
if i == N - 1 or labels[idx] != labels[sorted_labels[i+1]]:

pandas/core/groupby.py

+64-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import types
2-
from functools import wraps
2+
from functools import wraps, partial
33
import numpy as np
44
import datetime
55
import collections
@@ -1457,25 +1457,14 @@ def expanding(self, *args, **kwargs):
14571457
from pandas.core.window import ExpandingGroupby
14581458
return ExpandingGroupby(self, *args, **kwargs)
14591459

1460-
def _fill(self, how, limit=None):
1461-
labels, _, _ = self.grouper.group_info
1462-
1460+
def _fill(self, direction, limit=None):
14631461
# Need int value for Cython
14641462
if limit is None:
14651463
limit = -1
1466-
output = {}
1467-
if type(self) is DataFrameGroupBy:
1468-
for grp in self.grouper.groupings:
1469-
ser = grp.group_index.take(grp.labels)
1470-
output[ser.name] = ser.values
1471-
for name, obj in self._iterate_slices():
1472-
indexer = np.zeros_like(labels)
1473-
mask = isnull(obj.values).view(np.uint8)
1474-
libgroupby.group_fillna_indexer(indexer, mask, labels, how,
1475-
limit)
1476-
output[name] = algorithms.take_nd(obj.values, indexer)
14771464

1478-
return self._wrap_transformed_output(output)
1465+
return self._get_cythonized_result('group_fillna_indexer',
1466+
self.grouper, needs_mask=True,
1467+
direction=direction, limit=limit)
14791468

14801469
@Substitution(name='groupby')
14811470
def pad(self, limit=None):
@@ -1863,6 +1852,52 @@ def cummax(self, axis=0, **kwargs):
18631852

18641853
return self._cython_transform('cummax', numeric_only=False)
18651854

1855+
def _get_cythonized_result(self, how, grouper, needs_mask=False,
1856+
needs_ngroups=False, **kwargs):
1857+
"""Get result for Cythonized functions
1858+
1859+
Parameters
1860+
----------
1861+
how : str, Cythonized function name to be called
1862+
grouper : Grouper object containing pertinent group info
1863+
needs_mask : bool, default False
1864+
Whether boolean mask needs to be part of the Cython call signature
1865+
needs_ngroups : bool, default False
1866+
Whether number of groups part of the Cython call signature
1867+
**kwargs : dict
1868+
Extra arguments required for the given function. This method
1869+
internally stores an OrderedDict that maps those keywords to
1870+
positional arguments before calling the Cython layer
1871+
1872+
Returns
1873+
-------
1874+
GroupBy object populated with appropriate result(s)
1875+
"""
1876+
exp_kwds = collections.OrderedDict([
1877+
(('group_fillna_indexer'), ('direction', 'limit')),
1878+
(('group_shift_indexer'), ('nperiods',))])
1879+
1880+
labels, _, ngroups = grouper.group_info
1881+
output = collections.OrderedDict()
1882+
base_func = getattr(libgroupby, how)
1883+
1884+
for name, obj in self._iterate_slices():
1885+
indexer = np.zeros_like(labels)
1886+
func = partial(base_func, indexer, labels)
1887+
if needs_mask:
1888+
mask = isnull(obj.values).astype(np.uint8, copy=False)
1889+
func = partial(func, mask)
1890+
1891+
if needs_ngroups:
1892+
func = partial(func, ngroups)
1893+
1894+
# Convert any keywords into positional arguments
1895+
func = partial(func, *(kwargs[x] for x in exp_kwds[how]))
1896+
func() # Call func to modify indexer values in place
1897+
output[name] = algorithms.take_nd(obj.values, indexer)
1898+
1899+
return self._wrap_transformed_output(output)
1900+
18661901
@Substitution(name='groupby')
18671902
@Appender(_doc_template)
18681903
def shift(self, periods=1, freq=None, axis=0):
@@ -1880,17 +1915,10 @@ def shift(self, periods=1, freq=None, axis=0):
18801915
if freq is not None or axis != 0:
18811916
return self.apply(lambda x: x.shift(periods, freq, axis))
18821917

1883-
labels, _, ngroups = self.grouper.group_info
1884-
1885-
# filled in by Cython
1886-
indexer = np.zeros_like(labels)
1887-
libgroupby.group_shift_indexer(indexer, labels, ngroups, periods)
1918+
return self._get_cythonized_result('group_shift_indexer',
1919+
self.grouper, needs_ngroups=True,
1920+
nperiods=periods)
18881921

1889-
output = {}
1890-
for name, obj in self._iterate_slices():
1891-
output[name] = algorithms.take_nd(obj.values, indexer)
1892-
1893-
return self._wrap_transformed_output(output)
18941922

18951923
@Substitution(name='groupby')
18961924
@Appender(_doc_template)
@@ -3597,7 +3625,6 @@ def describe(self, **kwargs):
35973625
def value_counts(self, normalize=False, sort=True, ascending=False,
35983626
bins=None, dropna=True):
35993627

3600-
from functools import partial
36013628
from pandas.core.reshape.tile import cut
36023629
from pandas.core.reshape.merge import _get_join_indexers
36033630

@@ -4605,9 +4632,18 @@ def _apply_to_column_groupbys(self, func):
46054632
in self._iterate_column_groupbys()),
46064633
keys=self._selected_obj.columns, axis=1)
46074634

4635+
def _fill(self, direction, limit=None):
4636+
"""Overriden method to concat grouped columns in output"""
4637+
res = super()._fill(direction, limit=limit)
4638+
output = collections.OrderedDict()
4639+
for grp in self.grouper.groupings:
4640+
ser = grp.group_index.take(grp.labels)
4641+
output[ser.name] = ser.values
4642+
4643+
return self._wrap_transformed_output(output).join(res)
4644+
46084645
def count(self):
46094646
""" Compute count of group, excluding missing values """
4610-
from functools import partial
46114647
from pandas.core.dtypes.missing import _isna_ndarraylike as isna
46124648

46134649
data, _ = self._get_data_to_aggregate()

0 commit comments

Comments
 (0)