
BUG: Preserve subclassing with groupby operations #28573

Closed · wants to merge 5 commits

Changes from 1 commit
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.25.2.rst
@@ -78,7 +78,7 @@ Groupby/resample/rolling

- Bug incorrectly raising an ``IndexError`` when passing a list of quantiles to :meth:`pandas.core.groupby.DataFrameGroupBy.quantile` (:issue:`28113`).
- Bug in :meth:`pandas.core.groupby.GroupBy.shift`, :meth:`pandas.core.groupby.GroupBy.bfill` and :meth:`pandas.core.groupby.GroupBy.ffill` where timezone information would be dropped (:issue:`19995`, :issue:`27992`)
-
- Bug in groupby/resample where subclasses were not returned from groupby operations (:issue:`28330`)
-
-

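For context, a minimal sketch of the behavior this entry describes, assuming a build that includes this fix (the subclass names here are hypothetical, not part of the change):

import pandas as pd

class SubSeries(pd.Series):
    @property
    def _constructor(self):
        return SubSeries

class SubDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return SubDataFrame

    # hypothetical subclass wiring: slicing a SubDataFrame yields SubSeries
    _constructor_sliced = SubSeries

df = SubDataFrame({"g": ["a", "a", "b"], "x": [1, 2, 3]})
result = df.groupby("g").sum()
assert isinstance(result, SubDataFrame)  # previously came back as a plain DataFrame
assert isinstance(result["x"], SubSeries)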
73 changes: 41 additions & 32 deletions pandas/core/groupby/generic.py
@@ -337,7 +337,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):
if not len(result_columns) and errors is not None:
raise errors

return DataFrame(result, columns=result_columns)
return self.obj._constructor(result, columns=result_columns)

def _decide_output_index(self, output, labels):
if len(output) == len(labels):
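
For context: the recurring change in this file swaps hard-coded DataFrame(...) / Series(...) calls for self.obj._constructor(...), the hook through which pandas subclasses advertise their own type. A minimal sketch of that hook (hypothetical subclass name, not part of this diff):

import pandas as pd

class MyFrame(pd.DataFrame):
    @property
    def _constructor(self):
        # pandas internals call this whenever they rebuild a frame,
        # so results come back as MyFrame instead of a plain DataFrame
        return MyFrame

mf = MyFrame({"a": [1, 2]})
assert isinstance(mf.copy(), MyFrame)  # copy() builds its result via _constructor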
@@ -356,7 +356,7 @@ def _decide_output_index(self, output, labels):

def _wrap_applied_output(self, keys, values, not_indexed_same=False):
if len(keys) == 0:
return DataFrame(index=keys)
return self.obj._constructor(index=keys)

key_names = self.grouper.names

@@ -372,7 +372,7 @@ def first_not_none(values):
if v is None:
# GH9684. If all values are None, then this will throw an error.
# We'd prefer it return an empty dataframe.
return DataFrame()
return self.obj._constructor()
elif isinstance(v, DataFrame):
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
elif self.grouper.groupings is not None:
@@ -401,7 +401,7 @@ def first_not_none(values):
# make Nones an empty object
v = first_not_none(values)
if v is None:
return DataFrame()
return self.obj._constructor()
elif isinstance(v, NDFrame):
values = [
x if x is not None else v._constructor(**v._construct_axes_dict())
@@ -467,7 +467,7 @@ def first_not_none(values):
or isinstance(key_index, MultiIndex)
):
stacked_values = np.vstack([np.asarray(v) for v in values])
result = DataFrame(
result = self.obj._constructor(
stacked_values, index=key_index, columns=index
)
else:
@@ -484,14 +484,16 @@ def first_not_none(values):
result.columns = index
else:
stacked_values = np.vstack([np.asarray(v) for v in values])
result = DataFrame(
result = self.obj._constructor(
stacked_values.T, index=v.index, columns=key_index
)

except (ValueError, AttributeError):
# GH1738: values is a list of arrays of unequal lengths; fall
# through to the outer else clause
return Series(values, index=key_index, name=self._selection_name)
return self.obj._constructor_sliced(
values, index=key_index, name=self._selection_name
)

# if we have date/time like in the original, then coerce dates
# as we are stacking can easily have object dtypes here
@@ -510,7 +512,7 @@ def first_not_none(values):
# self._selection_name not passed through to Series as the
# result should not take the name of original selection
# of columns
return Series(values, index=key_index)._convert(
return self.obj._constructor_sliced(values, index=key_index)._convert(
datetime=True, coerce=coerce
)

@@ -554,7 +556,7 @@ def _transform_general(self, func, *args, **kwargs):
r.columns = group.columns
r.index = group.index
else:
r = DataFrame(
r = self.obj._constructor(
np.concatenate([res.values] * len(group.index)).reshape(
group.shape
),
@@ -681,7 +683,7 @@ def _transform_item_by_item(self, obj, wrapper):
if len(output) < len(obj.columns):
columns = columns.take(inds)

return DataFrame(output, index=obj.index, columns=columns)
return self.obj._constructor(output, index=obj.index, columns=columns)

def filter(self, func, dropna=True, *args, **kwargs):
"""
@@ -706,7 +708,7 @@ def filter(self, func, dropna=True, *args, **kwargs):

Examples
--------
>>> df = pd.DataFrame({'A' : ['foo', 'bar', 'foo', 'bar',
... 'foo', 'bar'],
... 'B' : [1, 2, 3, 4, 5, 6],
... 'C' : [2.0, 5., 8., 1., 2., 9.]})
@@ -875,7 +877,7 @@ def aggregate(self, func=None, *args, **kwargs):
result = self._aggregate_named(func, *args, **kwargs)

index = Index(sorted(result), name=self.grouper.names[0])
ret = Series(result, index=index)
ret = self.obj._constructor(result, index=index)

if not self.as_index: # pragma: no cover
print("Warning, ignoring as_index=True")
@@ -943,20 +945,21 @@ def _aggregate_multiple_funcs(self, arg, _level):
# let higher level handle
if _level:
return results

return DataFrame(results, columns=columns)
return self.obj._constructor_expanddim(results, columns=columns)

def _wrap_output(self, output, index, names=None):
""" common agg/transform wrapping logic """
output = output[self._selection_name]

if names is not None:
return DataFrame(output, index=index, columns=names)
return self.obj._constructor_expanddim(output, index=index, columns=names)
else:
name = self._selection_name
if name is None:
name = self._selected_obj.name
return Series(output, index=index, name=name)
return self.obj._constructor(output, index=index, name=name)

def _wrap_aggregated_output(self, output, names=None):
result = self._wrap_output(
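
For context: in SeriesGroupBy methods like the ones above, self.obj is a Series, so same-dimension results go through _constructor while DataFrame-shaped results need _constructor_expanddim; DataFrameGroupBy uses _constructor_sliced for the opposite direction. A minimal sketch of the dimension-changing hooks (hypothetical names, not part of this diff):

import pandas as pd

class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        # Series -> DataFrame direction (e.g. to_frame, unstack)
        return MyFrame

class MyFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyFrame

    # DataFrame -> Series direction (e.g. selecting a single column)
    _constructor_sliced = MySeries

s = MySeries([1, 2], name="x")
assert isinstance(s.to_frame(), MyFrame)
assert isinstance(MyFrame({"x": [1, 2]})["x"], MySeries)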
@@ -970,7 +973,7 @@ def _wrap_transformed_output(self, output, names=None):
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
if len(keys) == 0:
# GH #6265
return Series([], name=self._selection_name, index=keys)
return self.obj._constructor([], name=self._selection_name, index=keys)

def _get_index():
if self.grouper.nkeys > 1:
@@ -982,7 +985,9 @@ def _get_index():
if isinstance(values[0], dict):
# GH #823 #24880
index = _get_index()
result = self._reindex_output(DataFrame(values, index=index))
result = self._reindex_output(
self.obj._constructor_expanddim(values, index=index)
)
# if self.observed is False,
# keep all-NaN rows created while re-indexing
result = result.stack(dropna=self.observed)
@@ -996,7 +1001,9 @@ def _get_index():
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
else:
# GH #6265 #24880
result = Series(data=values, index=_get_index(), name=self._selection_name)
result = self.obj._constructor(
data=values, index=_get_index(), name=self._selection_name
)
return self._reindex_output(result)

def _aggregate_named(self, func, *args, **kwargs):
@@ -1052,7 +1059,7 @@ def transform(self, func, *args, **kwargs):

result = concat(results).sort_index()
else:
result = Series()
result = self.obj._constructor()

# we will only try to coerce the result type if
# we have a numeric dtype, as these are *always* udfs
@@ -1078,7 +1085,7 @@ def _transform_fast(self, func, func_nm):
out = algorithms.take_1d(func()._values, ids)
if cast:
out = self._try_cast(out, self.obj)
return Series(out, index=self.obj.index, name=self.obj.name)
return self.obj._constructor(out, index=self.obj.index, name=self.obj.name)

def filter(self, func, dropna=True, *args, **kwargs): # noqa
"""
@@ -1193,7 +1200,7 @@ def nunique(self, dropna=True):
res, out = np.zeros(len(ri), dtype=out.dtype), res
res[ids[idx]] = out

return Series(res, index=ri, name=self._selection_name)
return self.obj._constructor(res, index=ri, name=self._selection_name)

@Appender(Series.describe.__doc__)
def describe(self, **kwargs):
@@ -1233,7 +1240,7 @@ def value_counts(
else:

# lab is a Categorical with categories an IntervalIndex
lab = cut(Series(val), bins, include_lowest=True)
lab = cut(self.obj._constructor(val), bins, include_lowest=True)
lev = lab.cat.categories
lab = lev.take(lab.cat.codes)
llab = lambda lab, inc: lab[inc]._multiindex.codes[-1]
@@ -1293,7 +1300,7 @@ def value_counts(

if is_integer_dtype(out):
out = ensure_int64(out)
return Series(out, index=mi, name=self._selection_name)
return self.obj._constructor(out, index=mi, name=self._selection_name)

# for compat. with libgroupby.value_counts need to ensure every
# bin is present at every index level, null filled with zeros
@@ -1322,7 +1329,7 @@ def value_counts(

if is_integer_dtype(out):
out = ensure_int64(out)
return Series(out, index=mi, name=self._selection_name)
return self.obj._constructor(out, index=mi, name=self._selection_name)

def count(self):
"""
@@ -1341,7 +1348,7 @@ def count(self):
minlength = ngroups or 0
out = np.bincount(ids[mask], minlength=minlength)

return Series(
return self.obj._constructor(
out,
index=self.grouper.result_index,
name=self._selection_name,
@@ -1513,9 +1520,11 @@ def _wrap_generic_output(self, result, obj):
result_index = self.grouper.levels[0]

if self.axis == 0:
return DataFrame(result, index=obj.columns, columns=result_index).T
return self.obj._constructor(
result, index=obj.columns, columns=result_index
).T
else:
return DataFrame(result, index=obj.index, columns=result_index)
return self.obj._constructor(result, index=obj.index, columns=result_index)

def _get_data_to_aggregate(self):
obj = self._obj_with_exclusions
@@ -1548,33 +1557,33 @@ def _wrap_aggregated_output(self, output, names=None):
output_keys = self._decide_output_index(output, agg_labels)

if not self.as_index:
result = DataFrame(output, columns=output_keys)
result = self.obj._constructor(output, columns=output_keys)
self._insert_inaxis_grouper_inplace(result)
result = result._consolidate()
else:
index = self.grouper.result_index
result = DataFrame(output, index=index, columns=output_keys)
result = self.obj._constructor(output, index=index, columns=output_keys)

if self.axis == 1:
result = result.T

return self._reindex_output(result)._convert(datetime=True)

def _wrap_transformed_output(self, output, names=None):
return DataFrame(output, index=self.obj.index)
return self.obj._constructor(output, index=self.obj.index)

def _wrap_agged_blocks(self, items, blocks):
if not self.as_index:
index = np.arange(blocks[0].values.shape[-1])
mgr = BlockManager(blocks, [items, index])
result = DataFrame(mgr)
result = self.obj._constructor(mgr)

self._insert_inaxis_grouper_inplace(result)
result = result._consolidate()
else:
index = self.grouper.result_index
mgr = BlockManager(blocks, [items, index])
result = DataFrame(mgr)
result = self.obj._constructor(mgr)

if self.axis == 1:
result = result.T
28 changes: 28 additions & 0 deletions pandas/tests/groupby/test_groupby.py
@@ -134,6 +134,34 @@ def func(dataf):
result = df.groupby("X", squeeze=False).count()
assert isinstance(result, DataFrame)

# https://github.com/pandas-dev/pandas/issues/28330
Contributor: ideally we split these into several tests, but could be a followup

# Test groupby operations on subclassed dataframes/series
class ChildSeries(Series):
pass

Contributor: actually these tests are different from the ones above; can you make a new test?

class ChildDataFrame(DataFrame):
@property
def _constructor(self):
return ChildDataFrame

_constructor_sliced = ChildSeries

ChildSeries._constructor_expanddim = ChildDataFrame

cdf = ChildDataFrame(
[
{"val1": 1, "val2": 20},
{"val1": 1, "val2": 19},
Contributor: also please parameterize over as many functions as you can here (see the sketch after the test)

Member: You can use reduction_func from conftest. Probably worth adding something similar for transforming functions as well.

{"val1": 2, "val2": 27},
{"val1": 2, "val2": 12},
Contributor: can you update the comments here?

]
)
result = cdf.groupby("val1").sum()
assert isinstance(result, ChildDataFrame)
assert isinstance(result, DataFrame)
assert isinstance(result["val2"], ChildSeries)
assert isinstance(result["val2"], Series)
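
A sketch of the parameterization the reviewers suggest, assuming the reduction_func fixture from pandas/tests/groupby/conftest.py and the ChildSeries/ChildDataFrame definitions moved to module level; a full test would special-case kernels that need extra arguments:

import pytest

def test_groupby_reductions_preserve_subclass(reduction_func):
    if reduction_func in ("corrwith", "nth"):
        pytest.skip("needs extra arguments; would be covered separately")
    cdf = ChildDataFrame({"val1": [1, 1, 2, 2], "val2": [20, 19, 27, 12]})
    result = getattr(cdf.groupby("val1"), reduction_func)()
    # size and ngroup return Series-shaped results, hence the tuple of types
    assert isinstance(result, (ChildDataFrame, ChildSeries))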


def test_inconsistent_return_type():
# GH5592
Expand Down