@@ -37,8 +37,23 @@ class _Unstacker(object):
37
37
38
38
Parameters
39
39
----------
40
+ values : ndarray
41
+ Values of DataFrame to "Unstack"
42
+ index : object
43
+ Pandas ``Index``
40
44
level : int or str, default last level
41
45
Level to "unstack". Accepts a name for the level.
46
+ value_columns : Index, optional
47
+ Pandas ``Index`` or ``MultiIndex`` object if unstacking a DataFrame
48
+ fill_value : scalar, optional
49
+ Default value to fill in missing values if subgroups do not have the
50
+ same set of labels. By default, missing values will be replaced with
51
+ the default fill value for that data type, NaN for float, NaT for
52
+ datetimelike, etc. For integer types, by default data will converted to
53
+ float and missing values will be set to NaN.
54
+ constructor : object
55
+ Pandas ``DataFrame`` or subclass used to create unstacked
56
+ response. If None, DataFrame or SparseDataFrame will be used.
42
57
43
58
Examples
44
59
--------
@@ -69,7 +84,7 @@ class _Unstacker(object):
69
84
"""
70
85
71
86
def __init__ (self , values , index , level = - 1 , value_columns = None ,
72
- fill_value = None ):
87
+ fill_value = None , constructor = None ):
73
88
74
89
self .is_categorical = None
75
90
self .is_sparse = is_sparse (values )
@@ -86,6 +101,14 @@ def __init__(self, values, index, level=-1, value_columns=None,
86
101
self .value_columns = value_columns
87
102
self .fill_value = fill_value
88
103
104
+ if constructor is None :
105
+ if self .is_sparse :
106
+ self .constructor = SparseDataFrame
107
+ else :
108
+ self .constructor = DataFrame
109
+ else :
110
+ self .constructor = constructor
111
+
89
112
if value_columns is None and values .shape [1 ] != 1 : # pragma: no cover
90
113
raise ValueError ('must pass column labels for multi-column data' )
91
114
@@ -173,8 +196,7 @@ def get_result(self):
173
196
ordered = ordered )
174
197
for i in range (values .shape [- 1 ])]
175
198
176
- klass = SparseDataFrame if self .is_sparse else DataFrame
177
- return klass (values , index = index , columns = columns )
199
+ return self .constructor (values , index = index , columns = columns )
178
200
179
201
def get_new_values (self ):
180
202
values = self .values
@@ -374,8 +396,9 @@ def pivot(self, index=None, columns=None, values=None):
374
396
index = self .index
375
397
else :
376
398
index = self [index ]
377
- indexed = Series (self [values ].values ,
378
- index = MultiIndex .from_arrays ([index , self [columns ]]))
399
+ indexed = self ._constructor_sliced (
400
+ self [values ].values ,
401
+ index = MultiIndex .from_arrays ([index , self [columns ]]))
379
402
return indexed .unstack (columns )
380
403
381
404
@@ -461,7 +484,8 @@ def unstack(obj, level, fill_value=None):
461
484
return obj .T .stack (dropna = False )
462
485
else :
463
486
unstacker = _Unstacker (obj .values , obj .index , level = level ,
464
- fill_value = fill_value )
487
+ fill_value = fill_value ,
488
+ constructor = obj ._constructor_expanddim )
465
489
return unstacker .get_result ()
466
490
467
491
@@ -470,12 +494,12 @@ def _unstack_frame(obj, level, fill_value=None):
470
494
unstacker = partial (_Unstacker , index = obj .index ,
471
495
level = level , fill_value = fill_value )
472
496
blocks = obj ._data .unstack (unstacker )
473
- klass = type (obj )
474
- return klass (blocks )
497
+ return obj ._constructor (blocks )
475
498
else :
476
499
unstacker = _Unstacker (obj .values , obj .index , level = level ,
477
500
value_columns = obj .columns ,
478
- fill_value = fill_value )
501
+ fill_value = fill_value ,
502
+ constructor = obj ._constructor )
479
503
return unstacker .get_result ()
480
504
481
505
@@ -528,8 +552,7 @@ def factorize(index):
528
552
new_values = new_values [mask ]
529
553
new_index = new_index [mask ]
530
554
531
- klass = type (frame )._constructor_sliced
532
- return klass (new_values , index = new_index )
555
+ return frame ._constructor_sliced (new_values , index = new_index )
533
556
534
557
535
558
def stack_multiple (frame , level , dropna = True ):
@@ -676,7 +699,7 @@ def _convert_level_number(level_num, columns):
676
699
new_index = MultiIndex (levels = new_levels , labels = new_labels ,
677
700
names = new_names , verify_integrity = False )
678
701
679
- result = DataFrame (new_data , index = new_index , columns = new_columns )
702
+ result = frame . _constructor (new_data , index = new_index , columns = new_columns )
680
703
681
704
# more efficient way to go about this? can do the whole masking biz but
682
705
# will only save a small amount of time...
0 commit comments