-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
WIP/BUG: Correct results for groupby(...).transform with null keys #45839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
2143c59
deb3edb
ad7bf3e
39e1438
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,7 @@ class providing the base-class of operations. | |
) | ||
from pandas.util._exceptions import find_stack_level | ||
|
||
from pandas.core.dtypes.cast import ensure_dtype_can_hold_na | ||
from pandas.core.dtypes.common import ( | ||
is_bool_dtype, | ||
is_datetime64_dtype, | ||
|
@@ -108,6 +109,7 @@ class providing the base-class of operations. | |
CategoricalIndex, | ||
Index, | ||
MultiIndex, | ||
RangeIndex, | ||
) | ||
from pandas.core.internals.blocks import ensure_block_shape | ||
import pandas.core.sample as sample | ||
|
@@ -663,7 +665,17 @@ def get_converter(s): | |
converter = get_converter(index_sample) | ||
names = (converter(name) for name in names) | ||
|
||
return [self.indices.get(name, []) for name in names] | ||
# return [self.indices.get(name, []) for name in names] | ||
# self.indices is a dict and doesn't handle looking up nulls in the groups | ||
from pandas import ( | ||
Index, | ||
Series, | ||
) | ||
|
||
index = Index(self.indices.keys(), tupleize_cols=False) | ||
indices = Series(self.indices.values(), index=index) | ||
result = [indices[name] for name in names] | ||
return result | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here The need to separate out the Index carefully is to handle odd cases like where cc @jbrockmendel @phofl if there are any ideas to solve this in some more direct way rather than using a Series. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is relevant, but possibly related to #43943? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think in index.pyx we have some code for canonicalizing NA values, e.g. float("nan") -> np.nan, so that dict lookups are better-behaved. Could be adapted/reused when defining There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks - this is precisely what I was looking for. However, I found a simpler way to use the codes directly; I've put up #45953. |
||
|
||
@final | ||
def _get_index(self, name): | ||
|
@@ -947,7 +959,12 @@ def curried(x): | |
if name in base.plotting_methods: | ||
return self.apply(curried) | ||
|
||
return self._python_apply_general(curried, self._obj_with_exclusions) | ||
result = self._python_apply_general(curried, self._obj_with_exclusions) | ||
if self.grouper.has_dropped_na and name in base.transformation_kernels: | ||
# result will have dropped rows due to nans, fill with null | ||
# and ensure index is ordered same as the input | ||
result = self._set_result_index_ordered(result) | ||
return result | ||
|
||
wrapper.__name__ = name | ||
return wrapper | ||
|
@@ -1086,7 +1103,6 @@ def _set_result_index_ordered( | |
) -> OutputFrameOrSeries: | ||
# set the result index on the passed values object and | ||
# return the new object, xref 8046 | ||
|
||
if self.grouper.is_monotonic: | ||
# shortcut if we have an already ordered grouper | ||
result.set_axis(self.obj._get_axis(self.axis), axis=self.axis, inplace=True) | ||
|
@@ -1098,16 +1114,12 @@ def _set_result_index_ordered( | |
) | ||
result.set_axis(original_positions, axis=self.axis, inplace=True) | ||
result = result.sort_index(axis=self.axis) | ||
|
||
dropped_rows = len(result.index) < len(self.obj.index) | ||
|
||
if dropped_rows: | ||
# get index by slicing original index according to original positions | ||
# slice drops attrs => use set_axis when no rows were dropped | ||
sorted_indexer = result.index | ||
result.index = self._selected_obj.index[sorted_indexer] | ||
else: | ||
result.set_axis(self.obj._get_axis(self.axis), axis=self.axis, inplace=True) | ||
obj_axis = self.obj._get_axis(self.axis) | ||
if self.grouper.has_dropped_na: | ||
# Fill in any missing values due to dropna - index here is integral | ||
# with values referring to the row of the input so can use RangeIndex | ||
result = result.reindex(RangeIndex(len(obj_axis)), axis=self.axis) | ||
result.set_axis(obj_axis, axis=self.axis, inplace=True) | ||
|
||
return result | ||
|
||
|
@@ -1650,6 +1662,11 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): | |
with com.temp_setattr(self, "observed", True): | ||
result = getattr(self, func)(*args, **kwargs) | ||
|
||
if func == "ngroup" and not self.grouper.has_dropped_na: | ||
# ngroup handles its own wrapping, as long as there aren't | ||
# dropped null keys | ||
return result | ||
|
||
if self._can_use_transform_fast(func, result): | ||
return self._wrap_transform_fast_result(result) | ||
|
||
|
@@ -1672,7 +1689,9 @@ def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT: | |
out = algorithms.take_nd(result._values, ids) | ||
output = obj._constructor(out, index=obj.index, name=obj.name) | ||
else: | ||
output = result.take(ids, axis=0) | ||
# Don't convert indices: negative indices need to give rise | ||
# to null values in the result | ||
output = result._take(ids, axis=0, convert_indices=False) | ||
output.index = obj.index | ||
return output | ||
|
||
|
@@ -1725,9 +1744,14 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray: | |
else: | ||
out = np.repeat(out[np.r_[run[1:], True]], rep) - out | ||
|
||
if self.grouper.has_dropped_na: | ||
out = np.where(ids == -1, np.nan, out.astype(np.float64, copy=False)) | ||
else: | ||
out = out.astype(np.int64, copy=False) | ||
|
||
rev = np.empty(count, dtype=np.intp) | ||
rev[sorter] = np.arange(count, dtype=np.intp) | ||
return out[rev].astype(np.int64, copy=False) | ||
return out[rev] | ||
|
||
# ----------------------------------------------------------------- | ||
|
||
|
@@ -2556,7 +2580,12 @@ def blk_func(values: ArrayLike) -> ArrayLike: | |
# then there will be no -1s in indexer, so we can use | ||
# the original dtype (no need to ensure_dtype_can_hold_na) | ||
if isinstance(values, np.ndarray): | ||
out = np.empty(values.shape, dtype=values.dtype) | ||
dtype = values.dtype | ||
if self.grouper.has_dropped_na: | ||
# ...unless there are any dropped null groups, | ||
# these give rise to nan in the result | ||
dtype = ensure_dtype_can_hold_na(values.dtype) | ||
out = np.empty(values.shape, dtype=dtype) | ||
else: | ||
out = type(values)._empty(values.shape, dtype=values.dtype) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment about why this is two lines instead of just
`indices = Series(self.indices)`?
I assume it's to avoid a MultiIndex.