10
10
import string
11
11
from typing import (
12
12
TYPE_CHECKING ,
13
+ Hashable ,
14
+ List ,
13
15
Optional ,
14
16
Tuple ,
15
17
cast ,
@@ -124,14 +126,13 @@ def merge(
124
126
merge .__doc__ = _merge_doc % "\n left : DataFrame"
125
127
126
128
127
- def _groupby_and_merge (by , on , left : DataFrame , right : DataFrame , merge_pieces ):
129
+ def _groupby_and_merge (by , left : DataFrame , right : DataFrame , merge_pieces ):
128
130
"""
129
131
groupby & merge; we are always performing a left-by type operation
130
132
131
133
Parameters
132
134
----------
133
135
by: field to group
134
- on: duplicates field
135
136
left: DataFrame
136
137
right: DataFrame
137
138
merge_pieces: function for merging
@@ -307,17 +308,15 @@ def _merger(x, y):
307
308
check = set (left_by ).difference (left .columns )
308
309
if len (check ) != 0 :
309
310
raise KeyError (f"{ check } not found in left columns" )
310
- result , _ = _groupby_and_merge (
311
- left_by , on , left , right , lambda x , y : _merger (x , y )
312
- )
311
+ result , _ = _groupby_and_merge (left_by , left , right , lambda x , y : _merger (x , y ))
313
312
elif right_by is not None :
314
313
if isinstance (right_by , str ):
315
314
right_by = [right_by ]
316
315
check = set (right_by ).difference (right .columns )
317
316
if len (check ) != 0 :
318
317
raise KeyError (f"{ check } not found in right columns" )
319
318
result , _ = _groupby_and_merge (
320
- right_by , on , right , left , lambda x , y : _merger (y , x )
319
+ right_by , right , left , lambda x , y : _merger (y , x )
321
320
)
322
321
else :
323
322
result = _merger (left , right )
@@ -708,7 +707,7 @@ def __init__(
708
707
if validate is not None :
709
708
self ._validate (validate )
710
709
711
- def get_result (self ):
710
+ def get_result (self ) -> DataFrame :
712
711
if self .indicator :
713
712
self .left , self .right = self ._indicator_pre_merge (self .left , self .right )
714
713
@@ -774,7 +773,7 @@ def _indicator_pre_merge(
774
773
775
774
return left , right
776
775
777
- def _indicator_post_merge (self , result ) :
776
+ def _indicator_post_merge (self , result : DataFrame ) -> DataFrame :
778
777
779
778
result ["_left_indicator" ] = result ["_left_indicator" ].fillna (0 )
780
779
result ["_right_indicator" ] = result ["_right_indicator" ].fillna (0 )
@@ -790,7 +789,7 @@ def _indicator_post_merge(self, result):
790
789
result = result .drop (labels = ["_left_indicator" , "_right_indicator" ], axis = 1 )
791
790
return result
792
791
793
- def _maybe_restore_index_levels (self , result ) :
792
+ def _maybe_restore_index_levels (self , result : DataFrame ) -> None :
794
793
"""
795
794
Restore index levels specified as `on` parameters
796
795
@@ -949,7 +948,6 @@ def _get_join_info(self):
949
948
self .left .index ,
950
949
self .right .index ,
951
950
left_indexer ,
952
- right_indexer ,
953
951
how = "right" ,
954
952
)
955
953
else :
@@ -961,7 +959,6 @@ def _get_join_info(self):
961
959
self .right .index ,
962
960
self .left .index ,
963
961
right_indexer ,
964
- left_indexer ,
965
962
how = "left" ,
966
963
)
967
964
else :
@@ -979,9 +976,8 @@ def _create_join_index(
979
976
index : Index ,
980
977
other_index : Index ,
981
978
indexer ,
982
- other_indexer ,
983
979
how : str = "left" ,
984
- ):
980
+ ) -> Index :
985
981
"""
986
982
Create a join index by rearranging one index to match another
987
983
@@ -1126,7 +1122,7 @@ def _get_merge_keys(self):
1126
1122
1127
1123
return left_keys , right_keys , join_names
1128
1124
1129
- def _maybe_coerce_merge_keys (self ):
1125
+ def _maybe_coerce_merge_keys (self ) -> None :
1130
1126
# we have valid merges but we may have to further
1131
1127
# coerce these if they are originally incompatible types
1132
1128
#
@@ -1285,7 +1281,7 @@ def _create_cross_configuration(
1285
1281
cross_col ,
1286
1282
)
1287
1283
1288
- def _validate_specification (self ):
1284
+ def _validate_specification (self ) -> None :
1289
1285
if self .how == "cross" :
1290
1286
if (
1291
1287
self .left_index
@@ -1372,7 +1368,7 @@ def _validate_specification(self):
1372
1368
if self .how != "cross" and len (self .right_on ) != len (self .left_on ):
1373
1369
raise ValueError ("len(right_on) must equal len(left_on)" )
1374
1370
1375
- def _validate (self , validate : str ):
1371
+ def _validate (self , validate : str ) -> None :
1376
1372
1377
1373
# Check uniqueness of each
1378
1374
if self .left_index :
@@ -1479,10 +1475,10 @@ def restore_dropped_levels_multijoin(
1479
1475
left : MultiIndex ,
1480
1476
right : MultiIndex ,
1481
1477
dropped_level_names ,
1482
- join_index ,
1483
- lindexer ,
1484
- rindexer ,
1485
- ):
1478
+ join_index : Index ,
1479
+ lindexer : np . ndarray ,
1480
+ rindexer : np . ndarray ,
1481
+ ) -> Tuple [ List [ Index ], np . ndarray , List [ Hashable ]] :
1486
1482
"""
1487
1483
*this is an internal non-public method*
1488
1484
@@ -1500,7 +1496,7 @@ def restore_dropped_levels_multijoin(
1500
1496
right index
1501
1497
dropped_level_names : str array
1502
1498
list of non-common level names
1503
- join_index : MultiIndex
1499
+ join_index : Index
1504
1500
the index of the join between the
1505
1501
common levels of left and right
1506
1502
lindexer : intp array
@@ -1514,8 +1510,8 @@ def restore_dropped_levels_multijoin(
1514
1510
levels of combined multiindexes
1515
1511
labels : intp array
1516
1512
labels of combined multiindexes
1517
- names : str array
1518
- names of combined multiindexes
1513
+ names : List[Hashable]
1514
+ names of combined multiindex levels
1519
1515
1520
1516
"""
1521
1517
@@ -1604,7 +1600,7 @@ def __init__(
1604
1600
sort = True , # factorize sorts
1605
1601
)
1606
1602
1607
- def get_result (self ):
1603
+ def get_result (self ) -> DataFrame :
1608
1604
join_index , left_indexer , right_indexer = self ._get_join_info ()
1609
1605
1610
1606
llabels , rlabels = _items_overlap_with_suffix (
@@ -1653,7 +1649,7 @@ def _asof_by_function(direction: str):
1653
1649
}
1654
1650
1655
1651
1656
- def _get_cython_type_upcast (dtype ):
1652
+ def _get_cython_type_upcast (dtype ) -> str :
1657
1653
""" Upcast a dtype to 'int64_t', 'double', or 'object' """
1658
1654
if is_integer_dtype (dtype ):
1659
1655
return "int64_t"
0 commit comments