1
1
from __future__ import annotations
2
2
3
3
import itertools
4
- from typing import TYPE_CHECKING
4
+ from typing import (
5
+ TYPE_CHECKING ,
6
+ cast ,
7
+ )
5
8
import warnings
6
9
7
10
import numpy as np
@@ -452,7 +455,7 @@ def _unstack_multiple(data, clocs, fill_value=None):
452
455
return unstacked
453
456
454
457
455
- def unstack (obj , level , fill_value = None ):
458
+ def unstack (obj : Series | DataFrame , level , fill_value = None ):
456
459
457
460
if isinstance (level , (tuple , list )):
458
461
if len (level ) != 1 :
@@ -489,19 +492,20 @@ def unstack(obj, level, fill_value=None):
489
492
)
490
493
491
494
492
- def _unstack_frame (obj , level , fill_value = None ):
495
+ def _unstack_frame (obj : DataFrame , level , fill_value = None ):
496
+ assert isinstance (obj .index , MultiIndex ) # checked by caller
497
+ unstacker = _Unstacker (obj .index , level = level , constructor = obj ._constructor )
498
+
493
499
if not obj ._can_fast_transpose :
494
- unstacker = _Unstacker (obj .index , level = level )
495
500
mgr = obj ._mgr .unstack (unstacker , fill_value = fill_value )
496
501
return obj ._constructor (mgr )
497
502
else :
498
- unstacker = _Unstacker (obj .index , level = level , constructor = obj ._constructor )
499
503
return unstacker .get_result (
500
504
obj ._values , value_columns = obj .columns , fill_value = fill_value
501
505
)
502
506
503
507
504
- def _unstack_extension_series (series , level , fill_value ):
508
+ def _unstack_extension_series (series : Series , level , fill_value ) -> DataFrame :
505
509
"""
506
510
Unstack an ExtensionArray-backed Series.
507
511
@@ -534,14 +538,14 @@ def _unstack_extension_series(series, level, fill_value):
534
538
return result
535
539
536
540
537
- def stack (frame , level = - 1 , dropna = True ):
541
+ def stack (frame : DataFrame , level = - 1 , dropna : bool = True ):
538
542
"""
539
543
Convert DataFrame to Series with multi-level Index. Columns become the
540
544
second level of the resulting hierarchical index
541
545
542
546
Returns
543
547
-------
544
- stacked : Series
548
+ stacked : Series or DataFrame
545
549
"""
546
550
547
551
def factorize (index ):
@@ -676,8 +680,10 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
676
680
)
677
681
678
682
679
- def _stack_multi_columns (frame , level_num = - 1 , dropna = True ):
680
- def _convert_level_number (level_num : int , columns ):
683
+ def _stack_multi_columns (
684
+ frame : DataFrame , level_num : int = - 1 , dropna : bool = True
685
+ ) -> DataFrame :
686
+ def _convert_level_number (level_num : int , columns : Index ):
681
687
"""
682
688
Logic for converting the level number to something we can safely pass
683
689
to swaplevel.
@@ -690,32 +696,36 @@ def _convert_level_number(level_num: int, columns):
690
696
691
697
return level_num
692
698
693
- this = frame .copy ()
699
+ this = frame .copy (deep = False )
700
+ mi_cols = this .columns # cast(MultiIndex, this.columns)
701
+ assert isinstance (mi_cols , MultiIndex ) # caller is responsible
694
702
695
703
# this makes life much simpler
696
- if level_num != frame . columns .nlevels - 1 :
704
+ if level_num != mi_cols .nlevels - 1 :
697
705
# roll levels to put selected level at end
698
- roll_columns = this . columns
699
- for i in range (level_num , frame . columns .nlevels - 1 ):
706
+ roll_columns = mi_cols
707
+ for i in range (level_num , mi_cols .nlevels - 1 ):
700
708
# Need to check if the ints conflict with level names
701
709
lev1 = _convert_level_number (i , roll_columns )
702
710
lev2 = _convert_level_number (i + 1 , roll_columns )
703
711
roll_columns = roll_columns .swaplevel (lev1 , lev2 )
704
- this .columns = roll_columns
712
+ this .columns = mi_cols = roll_columns
705
713
706
- if not this . columns ._is_lexsorted ():
714
+ if not mi_cols ._is_lexsorted ():
707
715
# Workaround the edge case where 0 is one of the column names,
708
716
# which interferes with trying to sort based on the first
709
717
# level
710
- level_to_sort = _convert_level_number (0 , this . columns )
718
+ level_to_sort = _convert_level_number (0 , mi_cols )
711
719
this = this .sort_index (level = level_to_sort , axis = 1 )
720
+ mi_cols = this .columns
712
721
713
- new_columns = _stack_multi_column_index (this .columns )
722
+ mi_cols = cast (MultiIndex , mi_cols )
723
+ new_columns = _stack_multi_column_index (mi_cols )
714
724
715
725
# time to ravel the values
716
726
new_data = {}
717
- level_vals = this . columns .levels [- 1 ]
718
- level_codes = sorted (set (this . columns .codes [- 1 ]))
727
+ level_vals = mi_cols .levels [- 1 ]
728
+ level_codes = sorted (set (mi_cols .codes [- 1 ]))
719
729
level_vals_nan = level_vals .insert (len (level_vals ), None )
720
730
721
731
level_vals_used = np .take (level_vals_nan , level_codes )
0 commit comments