6
6
TYPE_CHECKING ,
7
7
Any ,
8
8
Callable ,
9
+ Collection ,
9
10
Hashable ,
10
11
Iterable ,
11
12
List ,
98
99
if TYPE_CHECKING :
99
100
from pandas import (
100
101
CategoricalIndex ,
102
+ DataFrame ,
101
103
Series ,
102
104
)
103
105
@@ -323,7 +325,7 @@ def __new__(
323
325
if len (levels ) == 0 :
324
326
raise ValueError ("Must pass non-zero number of levels/codes" )
325
327
326
- result = object .__new__ (MultiIndex )
328
+ result = object .__new__ (cls )
327
329
result ._cache = {}
328
330
329
331
# we've already validated levels and codes, so shortcut here
@@ -503,7 +505,7 @@ def from_arrays(cls, arrays, sortorder=None, names=lib.no_default) -> MultiIndex
503
505
@names_compat
504
506
def from_tuples (
505
507
cls ,
506
- tuples ,
508
+ tuples : Iterable [ Tuple [ Hashable , ...]] ,
507
509
sortorder : Optional [int ] = None ,
508
510
names : Optional [Sequence [Hashable ]] = None ,
509
511
) -> MultiIndex :
@@ -546,6 +548,7 @@ def from_tuples(
546
548
raise TypeError ("Input must be a list / sequence of tuple-likes." )
547
549
elif is_iterator (tuples ):
548
550
tuples = list (tuples )
551
+ tuples = cast (Collection [Tuple [Hashable , ...]], tuples )
549
552
550
553
arrays : List [Sequence [Hashable ]]
551
554
if len (tuples ) == 0 :
@@ -560,7 +563,8 @@ def from_tuples(
560
563
elif isinstance (tuples , list ):
561
564
arrays = list (lib .to_object_array_tuples (tuples ).T )
562
565
else :
563
- arrays = zip (* tuples )
566
+ arrs = zip (* tuples )
567
+ arrays = cast (List [Sequence [Hashable ]], arrs )
564
568
565
569
return cls .from_arrays (arrays , sortorder = sortorder , names = names )
566
570
@@ -626,7 +630,7 @@ def from_product(
626
630
return cls (levels , codes , sortorder = sortorder , names = names )
627
631
628
632
@classmethod
629
- def from_frame (cls , df , sortorder = None , names = None ) -> MultiIndex :
633
+ def from_frame (cls , df : DataFrame , sortorder = None , names = None ) -> MultiIndex :
630
634
"""
631
635
Make a MultiIndex from a DataFrame.
632
636
@@ -762,7 +766,7 @@ def __len__(self) -> int:
762
766
# Levels Methods
763
767
764
768
@cache_readonly
765
- def levels (self ):
769
+ def levels (self ) -> FrozenList :
766
770
# Use cache_readonly to ensure that self.get_locs doesn't repeatedly
767
771
# create new IndexEngine
768
772
# https://github.com/pandas-dev/pandas/issues/31648
@@ -1293,7 +1297,7 @@ def _formatter_func(self, tup):
1293
1297
formatter_funcs = [level ._formatter_func for level in self .levels ]
1294
1298
return tuple (func (val ) for func , val in zip (formatter_funcs , tup ))
1295
1299
1296
- def _format_data (self , name = None ):
1300
+ def _format_data (self , name = None ) -> str :
1297
1301
"""
1298
1302
Return the formatted data as a unicode string
1299
1303
"""
@@ -1419,10 +1423,10 @@ def format(
1419
1423
# --------------------------------------------------------------------
1420
1424
# Names Methods
1421
1425
1422
- def _get_names (self ):
1426
+ def _get_names (self ) -> FrozenList :
1423
1427
return FrozenList (self ._names )
1424
1428
1425
- def _set_names (self , names , level = None , validate = True ):
1429
+ def _set_names (self , names , level = None , validate : bool = True ):
1426
1430
"""
1427
1431
Set new names on index. Each name has to be a hashable type.
1428
1432
@@ -1433,7 +1437,7 @@ def _set_names(self, names, level=None, validate=True):
1433
1437
level : int, level name, or sequence of int/level names (default None)
1434
1438
If the index is a MultiIndex (hierarchical), level(s) to set (None
1435
1439
for all levels). Otherwise level must be None
1436
- validate : boolean , default True
1440
+ validate : bool , default True
1437
1441
validate that the names match level lengths
1438
1442
1439
1443
Raises
@@ -1712,7 +1716,7 @@ def unique(self, level=None):
1712
1716
level = self ._get_level_number (level )
1713
1717
return self ._get_level_values (level = level , unique = True )
1714
1718
1715
- def to_frame (self , index = True , name = None ):
1719
+ def to_frame (self , index = True , name = None ) -> DataFrame :
1716
1720
"""
1717
1721
Create a DataFrame with the levels of the MultiIndex as columns.
1718
1722
@@ -2109,8 +2113,8 @@ def take(
2109
2113
2110
2114
na_value = - 1
2111
2115
2116
+ taken = [lab .take (indices ) for lab in self .codes ]
2112
2117
if allow_fill :
2113
- taken = [lab .take (indices ) for lab in self .codes ]
2114
2118
mask = indices == - 1
2115
2119
if mask .any ():
2116
2120
masked = []
@@ -2119,8 +2123,6 @@ def take(
2119
2123
label_values [mask ] = na_value
2120
2124
masked .append (np .asarray (label_values ))
2121
2125
taken = masked
2122
- else :
2123
- taken = [lab .take (indices ) for lab in self .codes ]
2124
2126
2125
2127
return MultiIndex (
2126
2128
levels = self .levels , codes = taken , names = self .names , verify_integrity = False
@@ -2644,7 +2646,9 @@ def _get_partial_string_timestamp_match_key(self, key):
2644
2646
2645
2647
return key
2646
2648
2647
- def _get_indexer (self , target : Index , method = None , limit = None , tolerance = None ):
2649
+ def _get_indexer (
2650
+ self , target : Index , method = None , limit = None , tolerance = None
2651
+ ) -> np .ndarray :
2648
2652
2649
2653
# empty indexer
2650
2654
if not len (target ):
@@ -3521,7 +3525,7 @@ def equals(self, other: object) -> bool:
3521
3525
3522
3526
return True
3523
3527
3524
- def equal_levels (self , other ) -> bool :
3528
+ def equal_levels (self , other : MultiIndex ) -> bool :
3525
3529
"""
3526
3530
Return True if the levels of both MultiIndex objects are the same
3527
3531
@@ -3537,7 +3541,7 @@ def equal_levels(self, other) -> bool:
3537
3541
# --------------------------------------------------------------------
3538
3542
# Set Methods
3539
3543
3540
- def _union (self , other , sort ):
3544
+ def _union (self , other , sort ) -> MultiIndex :
3541
3545
other , result_names = self ._convert_can_do_setop (other )
3542
3546
3543
3547
# We could get here with CategoricalIndex other
@@ -3579,7 +3583,7 @@ def _maybe_match_names(self, other):
3579
3583
names .append (None )
3580
3584
return names
3581
3585
3582
- def _intersection (self , other , sort = False ):
3586
+ def _intersection (self , other , sort = False ) -> MultiIndex :
3583
3587
other , result_names = self ._convert_can_do_setop (other )
3584
3588
3585
3589
lvals = self ._values
0 commit comments