Skip to content

Commit c554784

Browse files
committed
BUG: Stack/unstack do not return subclassed objects (GH15563)
1 parent 2dac793 commit c554784

File tree

5 files changed

+318
-22
lines changed

5 files changed

+318
-22
lines changed

doc/source/whatsnew/v0.23.0.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ Reshaping
361361
- Bug in :func:`Series.rank` where ``Series`` containing ``NaT`` modifies the ``Series`` inplace (:issue:`18521`)
362362
- Bug in :func:`cut` which fails when using readonly arrays (:issue:`18773`)
363363
- Bug in :func:`DataFrame.pivot_table` which fails when the ``aggfunc`` arg is of type string. The behavior is now consistent with other methods like ``agg`` and ``apply`` (:issue:`18713`)
364-
364+
- Bug in :func:`DataFrame.stack`, :func:`DataFrame.unstack`, :func:`Series.unstack` which were not returning subclasses (:issue:`15563`)
365+
-
365366

366367
Numeric
367368
^^^^^^^

pandas/core/reshape/melt.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def melt(frame, id_vars=None, value_vars=None, var_name=None,
8080
mdata[col] = np.asanyarray(frame.columns
8181
._get_level_values(i)).repeat(N)
8282

83-
from pandas import DataFrame
84-
return DataFrame(mdata, columns=mcolumns)
83+
return frame._constructor(mdata, columns=mcolumns)
8584

8685

8786
def lreshape(data, groups, dropna=True, label=None):
@@ -152,8 +151,7 @@ def lreshape(data, groups, dropna=True, label=None):
152151
if not mask.all():
153152
mdata = {k: v[mask] for k, v in compat.iteritems(mdata)}
154153

155-
from pandas import DataFrame
156-
return DataFrame(mdata, columns=id_cols + pivot_cols)
154+
return data._constructor(mdata, columns=id_cols + pivot_cols)
157155

158156

159157
def wide_to_long(df, stubnames, i, j, sep="", suffix=r'\d+'):

pandas/core/reshape/reshape.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,23 @@ class _Unstacker(object):
3737
3838
Parameters
3939
----------
40+
values : ndarray
41+
Values of DataFrame to "Unstack"
42+
index : object
43+
Pandas ``Index``
4044
level : int or str, default last level
4145
Level to "unstack". Accepts a name for the level.
46+
value_columns : Index, optional
47+
Pandas ``Index`` or ``MultiIndex`` object if unstacking a DataFrame
48+
fill_value : scalar, optional
49+
Default value to fill in missing values if subgroups do not have the
50+
same set of labels. By default, missing values will be replaced with
51+
the default fill value for that data type, NaN for float, NaT for
52+
datetimelike, etc. For integer types, by default data will be converted to
53+
float and missing values will be set to NaN.
54+
constructor : object
55+
Pandas ``DataFrame`` or subclass used to create unstacked
56+
response. If None, DataFrame or SparseDataFrame will be used.
4257
4358
Examples
4459
--------
@@ -69,7 +84,7 @@ class _Unstacker(object):
6984
"""
7085

7186
def __init__(self, values, index, level=-1, value_columns=None,
72-
fill_value=None):
87+
fill_value=None, constructor=None):
7388

7489
self.is_categorical = None
7590
self.is_sparse = is_sparse(values)
@@ -86,6 +101,14 @@ def __init__(self, values, index, level=-1, value_columns=None,
86101
self.value_columns = value_columns
87102
self.fill_value = fill_value
88103

104+
if constructor is None:
105+
if self.is_sparse:
106+
self.constructor = SparseDataFrame
107+
else:
108+
self.constructor = DataFrame
109+
else:
110+
self.constructor = constructor
111+
89112
if value_columns is None and values.shape[1] != 1: # pragma: no cover
90113
raise ValueError('must pass column labels for multi-column data')
91114

@@ -173,8 +196,7 @@ def get_result(self):
173196
ordered=ordered)
174197
for i in range(values.shape[-1])]
175198

176-
klass = SparseDataFrame if self.is_sparse else DataFrame
177-
return klass(values, index=index, columns=columns)
199+
return self.constructor(values, index=index, columns=columns)
178200

179201
def get_new_values(self):
180202
values = self.values
@@ -374,8 +396,9 @@ def pivot(self, index=None, columns=None, values=None):
374396
index = self.index
375397
else:
376398
index = self[index]
377-
indexed = Series(self[values].values,
378-
index=MultiIndex.from_arrays([index, self[columns]]))
399+
indexed = self._constructor_sliced(
400+
self[values].values,
401+
index=MultiIndex.from_arrays([index, self[columns]]))
379402
return indexed.unstack(columns)
380403

381404

@@ -461,7 +484,8 @@ def unstack(obj, level, fill_value=None):
461484
return obj.T.stack(dropna=False)
462485
else:
463486
unstacker = _Unstacker(obj.values, obj.index, level=level,
464-
fill_value=fill_value)
487+
fill_value=fill_value,
488+
constructor=obj._constructor_expanddim)
465489
return unstacker.get_result()
466490

467491

@@ -470,12 +494,12 @@ def _unstack_frame(obj, level, fill_value=None):
470494
unstacker = partial(_Unstacker, index=obj.index,
471495
level=level, fill_value=fill_value)
472496
blocks = obj._data.unstack(unstacker)
473-
klass = type(obj)
474-
return klass(blocks)
497+
return obj._constructor(blocks)
475498
else:
476499
unstacker = _Unstacker(obj.values, obj.index, level=level,
477500
value_columns=obj.columns,
478-
fill_value=fill_value)
501+
fill_value=fill_value,
502+
constructor=obj._constructor)
479503
return unstacker.get_result()
480504

481505

@@ -528,8 +552,7 @@ def factorize(index):
528552
new_values = new_values[mask]
529553
new_index = new_index[mask]
530554

531-
klass = type(frame)._constructor_sliced
532-
return klass(new_values, index=new_index)
555+
return frame._constructor_sliced(new_values, index=new_index)
533556

534557

535558
def stack_multiple(frame, level, dropna=True):
@@ -675,7 +698,7 @@ def _convert_level_number(level_num, columns):
675698
new_index = MultiIndex(levels=new_levels, labels=new_labels,
676699
names=new_names, verify_integrity=False)
677700

678-
result = DataFrame(new_data, index=new_index, columns=new_columns)
701+
result = frame._constructor(new_data, index=new_index, columns=new_columns)
679702

680703
# more efficient way to go about this? can do the whole masking biz but
681704
# will only save a small amount of time...

0 commit comments

Comments
 (0)