Skip to content

Commit 41b30df

Browse files
committed
ENH: DataFrame.stack() with 'level' a set or list of sets
1 parent 7d13fdd commit 41b30df

File tree

3 files changed

+257
-57
lines changed

3 files changed

+257
-57
lines changed

pandas/core/index.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,8 @@ def _validate_index_level(self, level):
820820
% (level, self.name))
821821

822822
def _get_level_number(self, level):
    """
    Resolve ``level`` to a level number for this index.

    A ``set`` of levels is resolved member by member (recursively) and the
    results are returned as a ``set``.  Any other value is passed to
    ``self._validate_index_level`` and then resolves to ``0``.
    """
    if not isinstance(level, set):
        # Scalar case: validation raises on a bad level, otherwise the
        # answer is always 0.
        self._validate_index_level(level)
        return 0
    return set(self._get_level_number(member) for member in level)
825827

@@ -3157,6 +3159,8 @@ def _from_elements(values, labels=None, levels=None, names=None,
31573159
return MultiIndex(levels, labels, names, sortorder=sortorder)
31583160

31593161
def _get_level_number(self, level):
3162+
if isinstance(level, set):
3163+
return set(self._get_level_number(lev) for lev in level)
31603164
try:
31613165
count = self.names.count(level)
31623166
if count > 1:
@@ -4850,7 +4854,7 @@ def _trim_front(strings):
48504854

48514855

48524856
def _sanitize_and_check(indexes):
4853-
kinds = list(set([type(index) for index in indexes]))
4857+
kinds = list(set(type(index) for index in indexes))
48544858

48554859
if list in kinds:
48564860
if len(kinds) > 1:
@@ -4871,9 +4875,9 @@ def _get_consensus_names(indexes):
48714875

48724876
# find the non-none names, need to tupleify to make
48734877
# the set hashable, then reverse on return
4874-
consensus_names = set([
4878+
consensus_names = set(
48754879
tuple(i.names) for i in indexes if all(n is not None for n in i.names)
4876-
])
4880+
)
48774881
if len(consensus_names) == 1:
48784882
return list(list(consensus_names)[0])
48794883
return [None] * indexes[0].nlevels

pandas/core/reshape.py

+80-54
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,14 @@ def get_compressed_ids(labels, sizes):
508508
return comp_index, obs_ids
509509

510510

511+
def _iterate_through_set(x):
    """
    Generator over the members of ``x`` when ``x`` is a ``set``; for any
    other object, yield ``x`` itself exactly once.

    Lets callers treat "a single level" and "a set of levels" uniformly.
    """
    # Wrap a non-set argument in a 1-tuple so one loop serves both cases.
    members = x if isinstance(x, set) else (x,)
    for member in members:
        yield member
517+
518+
511519
def stack(frame, level=-1, dropna=True):
512520
"""
513521
Convert DataFrame to Series with multi-level Index. Columns become the
@@ -517,19 +525,18 @@ def stack(frame, level=-1, dropna=True):
517525
-------
518526
stacked : Series
519527
"""
520-
N, K = frame.shape
521528
if isinstance(frame.columns, MultiIndex):
522-
if frame.columns._reference_duplicate_name(level):
529+
if any(frame.columns._reference_duplicate_name(lev)
530+
for lev in _iterate_through_set(level)):
523531
msg = ("Ambiguous reference to {0}. The column "
524532
"names are not unique.".format(level))
525533
raise ValueError(msg)
526-
527-
# Will also convert negative level numbers and check if out of bounds.
528-
level_num = frame.columns._get_level_number(level)
529-
530-
if isinstance(frame.columns, MultiIndex):
534+
# Will also convert negative level numbers and check if out of bounds.
535+
level_num = frame.columns._get_level_number(level)
531536
return _stack_multi_columns(frame, level_num=level_num, dropna=dropna)
532-
elif isinstance(frame.index, MultiIndex):
537+
538+
N, K = frame.shape
539+
if isinstance(frame.index, MultiIndex):
533540
new_levels = list(frame.index.levels)
534541
new_levels.append(frame.columns)
535542

@@ -559,13 +566,13 @@ def stack(frame, level=-1, dropna=True):
559566
def stack_multiple(frame, level, dropna=True):
560567
# If all passed levels match up to column names, no
561568
# ambiguity about what to do
562-
if all(lev in frame.columns.names for lev in level):
569+
if all(lev in frame.columns.names for levl in level for lev in _iterate_through_set(levl)):
563570
result = frame
564571
for lev in level:
565572
result = stack(result, lev, dropna=dropna)
566573

567574
# Otherwise, level numbers may change as each successive level is stacked
568-
elif all(isinstance(lev, int) for lev in level):
575+
elif all(isinstance(lev, int) for levl in level for lev in _iterate_through_set(levl)):
569576
# As each stack is done, the level numbers decrease, so we need
570577
# to account for that when level is a sequence of ints
571578
result = frame
@@ -576,16 +583,19 @@ def stack_multiple(frame, level, dropna=True):
576583
# Can't iterate directly through level as we might need to change
577584
# values as we go
578585
for index in range(len(level)):
579-
lev = level[index]
580-
result = stack(result, lev, dropna=dropna)
586+
levl = level[index]
587+
result = stack(result, levl, dropna=dropna)
581588
# Decrement all level numbers greater than current, as these
582-
# have now shifted down by one
589+
# have now shifted down
583590
updated_level = []
584591
for other in level:
585-
if other > lev:
586-
updated_level.append(other - 1)
592+
if isinstance(other, set):
593+
updated_level.append(set((othr - sum((othr > lev)
594+
for lev in _iterate_through_set(levl)))
595+
for othr in other))
587596
else:
588-
updated_level.append(other)
597+
updated_level.append(other - sum((other > lev)
598+
for lev in _iterate_through_set(levl)))
589599
level = updated_level
590600

591601
else:
@@ -616,85 +626,101 @@ def _convert_level_number(level_num, columns):
616626
this = frame.copy()
617627

618628
# this makes life much simpler
619-
if level_num != frame.columns.nlevels - 1:
620-
# roll levels to put selected level at end
621-
roll_columns = this.columns
622-
for i in range(level_num, frame.columns.nlevels - 1):
629+
# roll levels to put selected level(s) at end
630+
level_nums = level_num if isinstance(level_num, set) else set([level_num])
631+
roll_columns = this.columns
632+
for j, level_num in enumerate(sorted(level_nums, reverse=True)):
633+
for i in range(level_num, frame.columns.nlevels - (j + 1)):
623634
# Need to check if the ints conflict with level names
624635
lev1 = _convert_level_number(i, roll_columns)
625636
lev2 = _convert_level_number(i + 1, roll_columns)
626637
roll_columns = roll_columns.swaplevel(lev1, lev2)
627-
this.columns = roll_columns
638+
this.columns = roll_columns
628639

629640
if not this.columns.is_lexsorted():
630641
# Workaround the edge case where 0 is one of the column names,
631-
# which interferes with trying to sort based on the first
632-
# level
642+
# which interferes with trying to sort based on the first level
633643
level_to_sort = _convert_level_number(0, this.columns)
634644
this = this.sortlevel(level_to_sort, axis=1)
635645

636-
# tuple list excluding level for grouping columns
637-
if len(frame.columns.levels) > 2:
646+
num_levels_to_stack = len(level_nums)
647+
level_vals = this.columns.levels[-num_levels_to_stack:]
648+
level_labels = sorted(set(zip(*this.columns.labels[-num_levels_to_stack:])))
649+
level_vals_used = MultiIndex.from_tuples([tuple(level_vals[i][lab] for i, lab in enumerate(label))
650+
for label in level_labels],
651+
names=this.columns.names[-num_levels_to_stack:])
652+
levsize = len(level_labels)
653+
654+
# construct new_index
655+
N = len(this)
656+
if isinstance(this.index, MultiIndex):
657+
new_levels = list(this.index.levels)
658+
new_names = list(this.index.names)
659+
new_labels = [lab.repeat(levsize) for lab in this.index.labels]
660+
else:
661+
new_levels = [this.index]
662+
new_labels = [np.arange(N).repeat(levsize)]
663+
new_names = [this.index.name] # something better?
664+
new_levels += level_vals
665+
new_labels += [np.tile(labels, N) for labels in zip(*level_labels)]
666+
new_names += level_vals_used.names
667+
new_index = MultiIndex(levels=new_levels, labels=new_labels,
668+
names=new_names, verify_integrity=False)
669+
670+
# if stacking all levels in columns, result will be a Series
671+
if len(frame.columns.levels) == num_levels_to_stack:
672+
new_data = frame.values.ravel()
673+
if dropna:
674+
mask = notnull(new_data)
675+
new_data = new_data[mask]
676+
new_index = new_index[mask]
677+
return Series(new_data, index=new_index)
678+
679+
# result will be a DataFrame
680+
681+
# construct new_columns
682+
if len(frame.columns.levels) > (num_levels_to_stack + 1):
683+
# result columns will be a MultiIndex
684+
# tuple list excluding level for grouping columns
638685
tuples = list(zip(*[
639686
lev.take(lab) for lev, lab in
640-
zip(this.columns.levels[:-1], this.columns.labels[:-1])
687+
zip(this.columns.levels[:-num_levels_to_stack],
688+
this.columns.labels[:-num_levels_to_stack])
641689
]))
642690
unique_groups = [key for key, _ in itertools.groupby(tuples)]
643-
new_names = this.columns.names[:-1]
691+
new_names = this.columns.names[:-num_levels_to_stack]
644692
new_columns = MultiIndex.from_tuples(unique_groups, names=new_names)
645693
else:
694+
# result columns will be an Index
646695
new_columns = unique_groups = this.columns.levels[0]
647696

648-
# time to ravel the values
697+
# construct new_data
649698
new_data = {}
650-
level_vals = this.columns.levels[-1]
651-
level_labels = sorted(set(this.columns.labels[-1]))
652-
level_vals_used = level_vals[level_labels]
653-
levsize = len(level_labels)
654699
drop_cols = []
655700
for key in unique_groups:
656701
loc = this.columns.get_loc(key)
657702
slice_len = loc.stop - loc.start
658703
# can make more efficient?
659-
660704
if slice_len == 0:
661705
drop_cols.append(key)
662706
continue
663707
elif slice_len != levsize:
664708
chunk = this.ix[:, this.columns[loc]]
665-
chunk.columns = level_vals.take(chunk.columns.labels[-1])
709+
chunk.columns = MultiIndex.from_arrays([vals.take(labels) for (vals, labels)
710+
in zip(level_vals, chunk.columns.labels[-num_levels_to_stack:])],
711+
names=chunk.columns.names[-num_levels_to_stack:])
666712
value_slice = chunk.reindex(columns=level_vals_used).values
667713
else:
668714
if frame._is_mixed_type:
669715
value_slice = this.ix[:, this.columns[loc]].values
670716
else:
671717
value_slice = this.values[:, loc]
672-
673718
new_data[key] = value_slice.ravel()
674719

675720
if len(drop_cols) > 0:
676721
new_columns = new_columns - drop_cols
677722

678-
N = len(this)
679-
680-
if isinstance(this.index, MultiIndex):
681-
new_levels = list(this.index.levels)
682-
new_names = list(this.index.names)
683-
new_labels = [lab.repeat(levsize) for lab in this.index.labels]
684-
else:
685-
new_levels = [this.index]
686-
new_labels = [np.arange(N).repeat(levsize)]
687-
new_names = [this.index.name] # something better?
688-
689-
new_levels.append(frame.columns.levels[level_num])
690-
new_labels.append(np.tile(level_labels, N))
691-
new_names.append(frame.columns.names[level_num])
692-
693-
new_index = MultiIndex(levels=new_levels, labels=new_labels,
694-
names=new_names, verify_integrity=False)
695-
696723
result = DataFrame(new_data, index=new_index, columns=new_columns)
697-
698724
# more efficient way to go about this? can do the whole masking biz but
699725
# will only save a small amount of time...
700726
if dropna:

0 commit comments

Comments
 (0)