-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: retain masked EA dtypes in groupby with as_index=False #41373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
eaf4fa9
02005f4
93ed7e7
9fbd50f
d4a986a
9b1e560
fcfc4e4
ff5d851
cf6b37d
a617ee4
67aa1c6
03cd407
c9db010
238123b
446a77e
77e90ca
802d18d
f767706
4c83e3e
b7a8599
cb37f57
ac4402b
30b07f9
f019783
c882029
e22cf97
efe5976
6eca84d
a32cb83
0d64355
ecc3151
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -908,6 +908,23 @@ def reconstructed_codes(self) -> list[np.ndarray]: | |
ids, obs_ids, _ = self.group_info | ||
return decons_obs_group_ids(ids, obs_ids, self.shape, codes, xnull=True) | ||
|
||
@cache_readonly | ||
def result_arraylike(self) -> ArrayLike: | ||
""" | ||
Analogous to result_index, but returning an ndarray/ExtensionArray | ||
allowing us to retain ExtensionDtypes not supported by Index. | ||
""" | ||
# TODO: once Index supports arbitrary EAs, this can be removed in favor | ||
# of result_index | ||
if len(self.groupings) == 1: | ||
return self.groupings[0].result_arraylike | ||
|
||
codes = self.reconstructed_codes | ||
levels = [ping.result_arraylike for ping in self.groupings] | ||
return MultiIndex( | ||
levels=levels, codes=codes, verify_integrity=False, names=self.names | ||
)._values | ||
|
||
@cache_readonly | ||
def result_index(self) -> Index: | ||
if len(self.groupings) == 1: | ||
|
@@ -924,12 +941,12 @@ def get_group_levels(self) -> list[Index]: | |
# Note: only called from _insert_inaxis_grouper_inplace, which | ||
# is only called for BaseGrouper, never for BinGrouper | ||
if len(self.groupings) == 1: | ||
return [self.groupings[0].result_index] | ||
return [self.groupings[0].result_arraylike] | ||
|
||
name_list = [] | ||
for ping, codes in zip(self.groupings, self.reconstructed_codes): | ||
codes = ensure_platform_int(codes) | ||
levels = ping.result_index.take(codes) | ||
levels = ping.result_arraylike.take(codes) | ||
|
||
name_list.append(levels) | ||
|
||
|
@@ -991,7 +1008,10 @@ def agg_series(self, obj: Series, func: F) -> ArrayLike: | |
result = self._aggregate_series_fast(obj, func) | ||
cast_back = False | ||
|
||
npvalues = lib.maybe_convert_objects(result, try_float=False) | ||
convert_datetime = obj.dtype.kind == "M" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What breaks if we remove this inference entirely? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Huh, nothing now. I'm pretty sure there was something back when I did this. Will revert. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. updated + green |
||
npvalues = lib.maybe_convert_objects( | ||
result, try_float=False, convert_datetime=convert_datetime | ||
) | ||
if cast_back: | ||
# TODO: Is there a documented reason why we dont always cast_back? | ||
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -269,14 +269,14 @@ def test_grouping_grouper(self, data_for_grouping): | |
def test_groupby_extension_agg(self, as_index, data_for_grouping): | ||
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping}) | ||
result = df.groupby("B", as_index=as_index).A.mean() | ||
_, index = pd.factorize(data_for_grouping, sort=True) | ||
_, uniques = pd.factorize(data_for_grouping, sort=True) | ||
|
||
index = pd.Index(index, name="B") | ||
expected = pd.Series([3, 1], index=index, name="A") | ||
if as_index: | ||
index = pd.Index(uniques, name="B") | ||
expected = pd.Series([3, 1], index=index, name="A") | ||
self.assert_series_equal(result, expected) | ||
else: | ||
expected = expected.reset_index() | ||
expected = pd.DataFrame({"B": uniques, "A": [3, 1]}) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a user-facing change, right? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes |
||
self.assert_frame_equal(result, expected) | ||
|
||
def test_groupby_agg_extension(self, data_for_grouping): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you really need this state? seems very magical here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, the statefulness is unpleasant. #41375 starts to unwind it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you try to unwind first? this is adding a lot
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, let’s get #41375 in and then I’ll rebase and try to trim this down