Skip to content

Commit aa3e611

Browse files
authored
BUG: Preserve Series/DataFrame subclasses through groupby operations (#33884)
1 parent f9809fe commit aa3e611

File tree

7 files changed

+171
-36
lines changed

7 files changed

+171
-36
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -809,6 +809,7 @@ Groupby/resample/rolling
809809
- Bug in :meth:`DataFrame.groupby` where a ``ValueError`` would be raised when grouping by a categorical column with read-only categories and ``sort=False`` (:issue:`33410`)
810810
- Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`)
811811
- Bug in :meth:`Rolling.min` and :meth:`Rolling.max`: Growing memory usage after multiple calls when using a fixed window (:issue:`30726`)
812+
- Bug in :meth:`GroupBy.agg`, :meth:`GroupBy.transform`, and :meth:`GroupBy.resample` where subclasses are not preserved (:issue:`28330`)
812813

813814
Reshaping
814815
^^^^^^^^^

pandas/core/groupby/generic.py

+47-31
Original file line number | Diff line number | Diff line change
@@ -327,7 +327,7 @@ def _aggregate_multiple_funcs(self, arg):
327327
# let higher level handle
328328
return results
329329

330-
return DataFrame(results, columns=columns)
330+
return self.obj._constructor_expanddim(results, columns=columns)
331331

332332
def _wrap_series_output(
333333
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]], index: Index,
@@ -356,10 +356,12 @@ def _wrap_series_output(
356356

357357
result: Union[Series, DataFrame]
358358
if len(output) > 1:
359-
result = DataFrame(indexed_output, index=index)
359+
result = self.obj._constructor_expanddim(indexed_output, index=index)
360360
result.columns = columns
361361
else:
362-
result = Series(indexed_output[0], index=index, name=columns[0])
362+
result = self.obj._constructor(
363+
indexed_output[0], index=index, name=columns[0]
364+
)
363365

364366
return result
365367

@@ -418,7 +420,9 @@ def _wrap_transformed_output(
418420
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
419421
if len(keys) == 0:
420422
# GH #6265
421-
return Series([], name=self._selection_name, index=keys, dtype=np.float64)
423+
return self.obj._constructor(
424+
[], name=self._selection_name, index=keys, dtype=np.float64
425+
)
422426

423427
def _get_index() -> Index:
424428
if self.grouper.nkeys > 1:
@@ -430,7 +434,9 @@ def _get_index() -> Index:
430434
if isinstance(values[0], dict):
431435
# GH #823 #24880
432436
index = _get_index()
433-
result = self._reindex_output(DataFrame(values, index=index))
437+
result = self._reindex_output(
438+
self.obj._constructor_expanddim(values, index=index)
439+
)
434440
# if self.observed is False,
435441
# keep all-NaN rows created while re-indexing
436442
result = result.stack(dropna=self.observed)
@@ -444,7 +450,9 @@ def _get_index() -> Index:
444450
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
445451
else:
446452
# GH #6265 #24880
447-
result = Series(data=values, index=_get_index(), name=self._selection_name)
453+
result = self.obj._constructor(
454+
data=values, index=_get_index(), name=self._selection_name
455+
)
448456
return self._reindex_output(result)
449457

450458
def _aggregate_named(self, func, *args, **kwargs):
@@ -520,7 +528,7 @@ def _transform_general(
520528

521529
result = concat(results).sort_index()
522530
else:
523-
result = Series(dtype=np.float64)
531+
result = self.obj._constructor(dtype=np.float64)
524532

525533
# we will only try to coerce the result type if
526534
# we have a numeric dtype, as these are *always* user-defined funcs
@@ -543,7 +551,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
543551
out = algorithms.take_1d(result._values, ids)
544552
if cast:
545553
out = maybe_cast_result(out, self.obj, how=func_nm)
546-
return Series(out, index=self.obj.index, name=self.obj.name)
554+
return self.obj._constructor(out, index=self.obj.index, name=self.obj.name)
547555

548556
def filter(self, func, dropna=True, *args, **kwargs):
549557
"""
@@ -644,7 +652,7 @@ def nunique(self, dropna: bool = True) -> Series:
644652
res, out = np.zeros(len(ri), dtype=out.dtype), res
645653
res[ids[idx]] = out
646654

647-
result = Series(res, index=ri, name=self._selection_name)
655+
result = self.obj._constructor(res, index=ri, name=self._selection_name)
648656
return self._reindex_output(result, fill_value=0)
649657

650658
@doc(Series.describe)
@@ -746,7 +754,7 @@ def value_counts(
746754

747755
if is_integer_dtype(out):
748756
out = ensure_int64(out)
749-
return Series(out, index=mi, name=self._selection_name)
757+
return self.obj._constructor(out, index=mi, name=self._selection_name)
750758

751759
# for compat. with libgroupby.value_counts need to ensure every
752760
# bin is present at every index level, null filled with zeros
@@ -778,7 +786,7 @@ def build_codes(lev_codes: np.ndarray) -> np.ndarray:
778786

779787
if is_integer_dtype(out):
780788
out = ensure_int64(out)
781-
return Series(out, index=mi, name=self._selection_name)
789+
return self.obj._constructor(out, index=mi, name=self._selection_name)
782790

783791
def count(self) -> Series:
784792
"""
@@ -797,7 +805,7 @@ def count(self) -> Series:
797805
minlength = ngroups or 0
798806
out = np.bincount(ids[mask], minlength=minlength)
799807

800-
result = Series(
808+
result = self.obj._constructor(
801809
out,
802810
index=self.grouper.result_index,
803811
name=self._selection_name,
@@ -1195,11 +1203,11 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
11951203
if cannot_agg:
11961204
result_columns = result_columns.drop(cannot_agg)
11971205

1198-
return DataFrame(result, columns=result_columns)
1206+
return self.obj._constructor(result, columns=result_columns)
11991207

12001208
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12011209
if len(keys) == 0:
1202-
return DataFrame(index=keys)
1210+
return self.obj._constructor(index=keys)
12031211

12041212
key_names = self.grouper.names
12051213

@@ -1209,7 +1217,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12091217
if first_not_none is None:
12101218
# GH9684. If all values are None, then this will throw an error.
12111219
# We'd prefer it return an empty dataframe.
1212-
return DataFrame()
1220+
return self.obj._constructor()
12131221
elif isinstance(first_not_none, DataFrame):
12141222
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
12151223
elif self.grouper.groupings is not None:
@@ -1240,13 +1248,13 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12401248

12411249
# make Nones an empty object
12421250
if first_not_none is None:
1243-
return DataFrame()
1251+
return self.obj._constructor()
12441252
elif isinstance(first_not_none, NDFrame):
12451253

12461254
# this is to silence a DeprecationWarning
12471255
# TODO: Remove when default dtype of empty Series is object
12481256
kwargs = first_not_none._construct_axes_dict()
1249-
if first_not_none._constructor is Series:
1257+
if isinstance(first_not_none, Series):
12501258
backup = create_series_with_explicit_dtype(
12511259
**kwargs, dtype_if_empty=object
12521260
)
@@ -1313,7 +1321,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13131321
or isinstance(key_index, MultiIndex)
13141322
):
13151323
stacked_values = np.vstack([np.asarray(v) for v in values])
1316-
result = DataFrame(
1324+
result = self.obj._constructor(
13171325
stacked_values, index=key_index, columns=index
13181326
)
13191327
else:
@@ -1330,15 +1338,17 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13301338
result.columns = index
13311339
elif isinstance(v, ABCSeries):
13321340
stacked_values = np.vstack([np.asarray(v) for v in values])
1333-
result = DataFrame(
1341+
result = self.obj._constructor(
13341342
stacked_values.T, index=v.index, columns=key_index
13351343
)
13361344
else:
13371345
# GH#1738: values is list of arrays of unequal lengths
13381346
# fall through to the outer else clause
13391347
# TODO: sure this is right? we used to do this
13401348
# after raising AttributeError above
1341-
return Series(values, index=key_index, name=self._selection_name)
1349+
return self.obj._constructor_sliced(
1350+
values, index=key_index, name=self._selection_name
1351+
)
13421352

13431353
# if we have date/time like in the original, then coerce dates
13441354
# as we are stacking can easily have object dtypes here
@@ -1355,7 +1365,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13551365
# self._selection_name not passed through to Series as the
13561366
# result should not take the name of original selection
13571367
# of columns
1358-
return Series(values, index=key_index)
1368+
return self.obj._constructor_sliced(values, index=key_index)
13591369

13601370
else:
13611371
# Handle cases like BinGrouper
@@ -1385,7 +1395,9 @@ def _transform_general(
13851395
if cache_key not in NUMBA_FUNC_CACHE:
13861396
NUMBA_FUNC_CACHE[cache_key] = numba_func
13871397
# Return the result as a DataFrame for concatenation later
1388-
res = DataFrame(res, index=group.index, columns=group.columns)
1398+
res = self.obj._constructor(
1399+
res, index=group.index, columns=group.columns
1400+
)
13891401
else:
13901402
# Try slow path and fast path.
13911403
try:
@@ -1408,7 +1420,7 @@ def _transform_general(
14081420
r.columns = group.columns
14091421
r.index = group.index
14101422
else:
1411-
r = DataFrame(
1423+
r = self.obj._constructor(
14121424
np.concatenate([res.values] * len(group.index)).reshape(
14131425
group.shape
14141426
),
@@ -1484,7 +1496,9 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14841496
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
14851497
output.append(res)
14861498

1487-
return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)
1499+
return self.obj._constructor._from_arrays(
1500+
output, columns=result.columns, index=obj.index
1501+
)
14881502

14891503
def _define_paths(self, func, *args, **kwargs):
14901504
if isinstance(func, str):
@@ -1546,7 +1560,7 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
15461560
if len(output) < len(obj.columns):
15471561
columns = columns.take(inds)
15481562

1549-
return DataFrame(output, index=obj.index, columns=columns)
1563+
return self.obj._constructor(output, index=obj.index, columns=columns)
15501564

15511565
def filter(self, func, dropna=True, *args, **kwargs):
15521566
"""
@@ -1661,9 +1675,11 @@ def _wrap_frame_output(self, result, obj) -> DataFrame:
16611675
result_index = self.grouper.levels[0]
16621676

16631677
if self.axis == 0:
1664-
return DataFrame(result, index=obj.columns, columns=result_index).T
1678+
return self.obj._constructor(
1679+
result, index=obj.columns, columns=result_index
1680+
).T
16651681
else:
1666-
return DataFrame(result, index=obj.index, columns=result_index)
1682+
return self.obj._constructor(result, index=obj.index, columns=result_index)
16671683

16681684
def _get_data_to_aggregate(self) -> BlockManager:
16691685
obj = self._obj_with_exclusions
@@ -1707,7 +1723,7 @@ def _wrap_aggregated_output(
17071723
indexed_output = {key.position: val for key, val in output.items()}
17081724
columns = Index(key.label for key in output)
17091725

1710-
result = DataFrame(indexed_output)
1726+
result = self.obj._constructor(indexed_output)
17111727
result.columns = columns
17121728

17131729
if not self.as_index:
@@ -1740,7 +1756,7 @@ def _wrap_transformed_output(
17401756
indexed_output = {key.position: val for key, val in output.items()}
17411757
columns = Index(key.label for key in output)
17421758

1743-
result = DataFrame(indexed_output)
1759+
result = self.obj._constructor(indexed_output)
17441760
result.columns = columns
17451761
result.index = self.obj.index
17461762

@@ -1750,14 +1766,14 @@ def _wrap_agged_blocks(self, blocks: "Sequence[Block]", items: Index) -> DataFra
17501766
if not self.as_index:
17511767
index = np.arange(blocks[0].values.shape[-1])
17521768
mgr = BlockManager(blocks, axes=[items, index])
1753-
result = DataFrame(mgr)
1769+
result = self.obj._constructor(mgr)
17541770

17551771
self._insert_inaxis_grouper_inplace(result)
17561772
result = result._consolidate()
17571773
else:
17581774
index = self.grouper.result_index
17591775
mgr = BlockManager(blocks, axes=[items, index])
1760-
result = DataFrame(mgr)
1776+
result = self.obj._constructor(mgr)
17611777

17621778
if self.axis == 1:
17631779
result = result.T

pandas/core/groupby/groupby.py

+15-4
Original file line number | Diff line number | Diff line change
@@ -1185,6 +1185,14 @@ class GroupBy(_GroupBy[FrameOrSeries]):
11851185
more
11861186
"""
11871187

1188+
@property
1189+
def _obj_1d_constructor(self) -> Type["Series"]:
1190+
# GH28330 preserve subclassed Series/DataFrames
1191+
if isinstance(self.obj, DataFrame):
1192+
return self.obj._constructor_sliced
1193+
assert isinstance(self.obj, Series)
1194+
return self.obj._constructor
1195+
11881196
def _bool_agg(self, val_test, skipna):
11891197
"""
11901198
Shared func to call any / all Cython GroupBy implementations.
@@ -1423,8 +1431,11 @@ def size(self):
14231431
"""
14241432
result = self.grouper.size()
14251433

1426-
if isinstance(self.obj, Series):
1427-
result.name = self.obj.name
1434+
# GH28330 preserve subclassed Series/DataFrames through calls
1435+
if issubclass(self.obj._constructor, Series):
1436+
result = self._obj_1d_constructor(result, name=self.obj.name)
1437+
else:
1438+
result = self._obj_1d_constructor(result)
14281439
return self._reindex_output(result, fill_value=0)
14291440

14301441
@classmethod
@@ -2110,7 +2121,7 @@ def ngroup(self, ascending: bool = True):
21102121
"""
21112122
with _group_selection_context(self):
21122123
index = self._selected_obj.index
2113-
result = Series(self.grouper.group_info[0], index)
2124+
result = self._obj_1d_constructor(self.grouper.group_info[0], index)
21142125
if not ascending:
21152126
result = self.ngroups - 1 - result
21162127
return result
@@ -2172,7 +2183,7 @@ def cumcount(self, ascending: bool = True):
21722183
with _group_selection_context(self):
21732184
index = self._selected_obj.index
21742185
cumcounts = self._cumcount_array(ascending=ascending)
2175-
return Series(cumcounts, index)
2186+
return self._obj_1d_constructor(cumcounts, index)
21762187

21772188
@Substitution(name="groupby")
21782189
@Appender(_common_see_also)

pandas/core/reshape/concat.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -469,7 +469,9 @@ def get_result(self):
469469
# combine as columns in a frame
470470
else:
471471
data = dict(zip(range(len(self.objs)), self.objs))
472-
cons = DataFrame
472+
473+
# GH28330 Preserves subclassed objects through concat
474+
cons = self.objs[0]._constructor_expanddim
473475

474476
index, columns = self.new_axes
475477
df = cons(data, index=index)

pandas/tests/frame/test_subclass.py

+14
Original file line number | Diff line number | Diff line change
@@ -682,3 +682,17 @@ def test_asof(self):
682682

683683
result = df.asof("1989-12-31")
684684
assert isinstance(result, tm.SubclassedSeries)
685+
686+
def test_idxmin_preserves_subclass(self):
687+
# GH 28330
688+
689+
df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
690+
result = df.idxmin()
691+
assert isinstance(result, tm.SubclassedSeries)
692+
693+
def test_idxmax_preserves_subclass(self):
694+
# GH 28330
695+
696+
df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
697+
result = df.idxmax()
698+
assert isinstance(result, tm.SubclassedSeries)

0 commit comments

Comments (0)