Skip to content

WIP: groupby skipna #41399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def group_add(complexfloating_t[:, ::1] out,
int64_t[::1] counts,
ndarray[complexfloating_t, ndim=2] values,
const intp_t[:] labels,
Py_ssize_t min_count=0) -> None:
Py_ssize_t min_count=0,
bint skipna=True) -> None:
"""
Only aggregates on axis=0 using Kahan summation
"""
Expand Down Expand Up @@ -520,6 +521,13 @@ def group_add(complexfloating_t[:, ::1] out,
t = sumx[lab, j] + y
compensation[lab, j] = t - sumx[lab, j] - y
sumx[lab, j] = t
# dont skip nan
elif skipna == False:
sumx[lab, j] = NAN
break
# skip nan
else:
continue

for i in range(ncounts):
for j in range(K):
Expand All @@ -535,7 +543,8 @@ def group_prod(floating[:, ::1] out,
int64_t[::1] counts,
ndarray[floating, ndim=2] values,
const intp_t[:] labels,
Py_ssize_t min_count=0) -> None:
Py_ssize_t min_count=0,
bint skipna=True) -> None:
"""
Only aggregates on axis=0
"""
Expand Down Expand Up @@ -568,6 +577,11 @@ def group_prod(floating[:, ::1] out,
if val == val:
nobs[lab, j] += 1
prodx[lab, j] *= val
elif skipna == False:
prodx[lab, j] = NAN
break
else:
continue

for i in range(ncounts):
for j in range(K):
Expand All @@ -585,6 +599,7 @@ def group_var(floating[:, ::1] out,
ndarray[floating, ndim=2] values,
const intp_t[:] labels,
Py_ssize_t min_count=-1,
bint skipna=True,
int64_t ddof=1) -> None:
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
Expand Down Expand Up @@ -622,6 +637,11 @@ def group_var(floating[:, ::1] out,
oldmean = mean[lab, j]
mean[lab, j] += (val - oldmean) / nobs[lab, j]
out[lab, j] += (val - mean[lab, j]) * (val - oldmean)
elif skipna == False:
out[lab, j] = NAN
break
else:
continue

for i in range(ncounts):
for j in range(K):
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,10 @@ def _aggregate_multiple_funcs(self, arg):
return self.obj._constructor_expanddim(output, columns=columns)

def _cython_agg_general(
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1, skipna: bool = True
):
output: dict[base.OutputKey, ArrayLike] = {}
# MAYUKH
# Ideally we would be able to enumerate self._iterate_slices and use
# the index from enumeration as the key of output, but ohlc in particular
# returns a (n x 4) array. Output requires 1D ndarrays as values, so we
Expand All @@ -361,7 +362,7 @@ def _cython_agg_general(
continue

result = self.grouper._cython_operation(
"aggregate", obj._values, how, axis=0, min_count=min_count
"aggregate", obj._values, how, axis=0, min_count=min_count, skipna=skipna
)
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
Expand Down
11 changes: 7 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ def _agg_general(
*,
alias: str,
npfunc: Callable,
skipna=True,
):
with group_selection_context(self):
# try a cython aggregation if we can
Expand All @@ -1279,6 +1280,7 @@ def _agg_general(
alt=npfunc,
numeric_only=numeric_only,
min_count=min_count,
skipna=skipna
)
except DataError:
pass
Expand All @@ -1298,7 +1300,7 @@ def _agg_general(
return result.__finalize__(self.obj, method="groupby")

def _cython_agg_general(
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1, skipna: bool = False
):
raise AbstractMethodError(self)

Expand Down Expand Up @@ -1691,7 +1693,7 @@ def size(self) -> FrameOrSeriesUnion:

@final
@doc(_groupby_agg_method_template, fname="sum", no=True, mc=0)
def sum(self, numeric_only: bool = True, min_count: int = 0):
def sum(self, numeric_only: bool = True, min_count: int = 0, skipna=True):

# If we are grouping on categoricals we want unobserved categories to
# return zero, rather than the default of NaN which the reindexing in
Expand All @@ -1702,15 +1704,16 @@ def sum(self, numeric_only: bool = True, min_count: int = 0):
min_count=min_count,
alias="add",
npfunc=np.sum,
skipna=skipna
)

return self._reindex_output(result, fill_value=0)

@final
@doc(_groupby_agg_method_template, fname="prod", no=True, mc=0)
def prod(self, numeric_only: bool = True, min_count: int = 0):
def prod(self, numeric_only: bool = True, min_count: int = 0, skipna: bool = True):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod, skipna=skipna
)

@final
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,14 @@ def _cython_operation(
how: str,
axis: int,
min_count: int = -1,
skipna: bool = True,
mask: np.ndarray | None = None,
**kwargs,
) -> ArrayLike:
"""
Returns the values of a cython operation.
"""
#MAYUKH
orig_values = values
assert kind in ["transform", "aggregate"]

Expand All @@ -726,6 +728,7 @@ def _cython_operation(
dtype = values.dtype
is_numeric = is_numeric_dtype(dtype)

#MAYUKH
cy_op = WrappedCythonOp(kind=kind, how=how)

# can we do this operation with our cython functions
Expand All @@ -736,11 +739,11 @@ def _cython_operation(
if is_extension_array_dtype(dtype):
if isinstance(values, BaseMaskedArray) and func_uses_mask:
return self._masked_ea_wrap_cython_operation(
cy_op, kind, values, how, axis, min_count, **kwargs
cy_op, kind, values, how, axis, min_count, skipna, **kwargs
)
else:
return self._ea_wrap_cython_operation(
cy_op, kind, values, how, axis, min_count, **kwargs
cy_op, kind, values, how, axis, min_count, skipna, **kwargs
)

elif values.ndim == 1:
Expand All @@ -752,6 +755,7 @@ def _cython_operation(
how=how,
axis=1,
min_count=min_count,
skipna=skipna,
mask=mask,
**kwargs,
)
Expand Down Expand Up @@ -802,7 +806,8 @@ def _cython_operation(
is_datetimelike=is_datetimelike,
)
else:
func(result, counts, values, comp_ids, min_count)
#MAYUKH
func(result, counts, values, comp_ids, min_count, skipna)
elif kind == "transform":
# TODO: min_count
if func_uses_mask:
Expand Down