6
6
import datetime
7
7
from functools import partial
8
8
import string
9
+ from typing import TYPE_CHECKING , Optional , Tuple , Union
9
10
import warnings
10
11
11
12
import numpy as np
39
40
from pandas .core .dtypes .missing import isna , na_value_for_dtype
40
41
41
42
from pandas import Categorical , Index , MultiIndex
43
+ from pandas ._typing import FrameOrSeries
42
44
import pandas .core .algorithms as algos
43
45
from pandas .core .arrays .categorical import _recode_for_categories
44
46
import pandas .core .common as com
45
47
from pandas .core .frame import _merge_doc
46
48
from pandas .core .internals import _transform_index , concatenate_block_managers
47
49
from pandas .core .sorting import is_int64_overflow_possible
48
50
51
+ if TYPE_CHECKING :
52
+ from pandas import DataFrame , Series # noqa:F401
53
+
49
54
50
55
@Substitution ("\n left : DataFrame" )
51
56
@Appender (_merge_doc , indents = 0 )
52
57
def merge (
53
58
left ,
54
59
right ,
55
- how = "inner" ,
60
+ how : str = "inner" ,
56
61
on = None ,
57
62
left_on = None ,
58
63
right_on = None ,
59
- left_index = False ,
60
- right_index = False ,
61
- sort = False ,
64
+ left_index : bool = False ,
65
+ right_index : bool = False ,
66
+ sort : bool = False ,
62
67
suffixes = ("_x" , "_y" ),
63
- copy = True ,
64
- indicator = False ,
68
+ copy : bool = True ,
69
+ indicator : bool = False ,
65
70
validate = None ,
66
71
):
67
72
op = _MergeOperation (
@@ -86,7 +91,9 @@ def merge(
86
91
merge .__doc__ = _merge_doc % "\n left : DataFrame"
87
92
88
93
89
- def _groupby_and_merge (by , on , left , right , _merge_pieces , check_duplicates = True ):
94
+ def _groupby_and_merge (
95
+ by , on , left , right , _merge_pieces , check_duplicates : bool = True
96
+ ):
90
97
"""
91
98
groupby & merge; we are always performing a left-by type operation
92
99
@@ -172,7 +179,7 @@ def merge_ordered(
172
179
right_by = None ,
173
180
fill_method = None ,
174
181
suffixes = ("_x" , "_y" ),
175
- how = "outer" ,
182
+ how : str = "outer" ,
176
183
):
177
184
"""
178
185
Perform merge with optional filling/interpolation.
@@ -298,14 +305,14 @@ def merge_asof(
298
305
on = None ,
299
306
left_on = None ,
300
307
right_on = None ,
301
- left_index = False ,
302
- right_index = False ,
308
+ left_index : bool = False ,
309
+ right_index : bool = False ,
303
310
by = None ,
304
311
left_by = None ,
305
312
right_by = None ,
306
313
suffixes = ("_x" , "_y" ),
307
314
tolerance = None ,
308
- allow_exact_matches = True ,
315
+ allow_exact_matches : bool = True ,
309
316
direction = "backward" ,
310
317
):
311
318
"""
@@ -533,33 +540,33 @@ def merge_asof(
533
540
# TODO: only copy DataFrames when modification necessary
534
541
class _MergeOperation :
535
542
"""
536
- Perform a database (SQL) merge operation between two DataFrame objects
537
- using either columns as keys or their row indexes
543
+ Perform a database (SQL) merge operation between two DataFrame or Series
544
+ objects using either columns as keys or their row indexes
538
545
"""
539
546
540
547
_merge_type = "merge"
541
548
542
549
def __init__ (
543
550
self ,
544
- left ,
545
- right ,
546
- how = "inner" ,
551
+ left : Union [ "Series" , "DataFrame" ] ,
552
+ right : Union [ "Series" , "DataFrame" ] ,
553
+ how : str = "inner" ,
547
554
on = None ,
548
555
left_on = None ,
549
556
right_on = None ,
550
557
axis = 1 ,
551
- left_index = False ,
552
- right_index = False ,
553
- sort = True ,
558
+ left_index : bool = False ,
559
+ right_index : bool = False ,
560
+ sort : bool = True ,
554
561
suffixes = ("_x" , "_y" ),
555
- copy = True ,
556
- indicator = False ,
562
+ copy : bool = True ,
563
+ indicator : bool = False ,
557
564
validate = None ,
558
565
):
559
- left = validate_operand (left )
560
- right = validate_operand (right )
561
- self .left = self .orig_left = left
562
- self .right = self .orig_right = right
566
+ _left = _validate_operand (left )
567
+ _right = _validate_operand (right )
568
+ self .left = self .orig_left = _validate_operand ( _left ) # type: "DataFrame"
569
+ self .right = self .orig_right = _validate_operand ( _right ) # type: "DataFrame"
563
570
self .how = how
564
571
self .axis = axis
565
572
@@ -577,7 +584,7 @@ def __init__(
577
584
self .indicator = indicator
578
585
579
586
if isinstance (self .indicator , str ):
580
- self .indicator_name = self .indicator
587
+ self .indicator_name = self .indicator # type: Optional[str]
581
588
elif isinstance (self .indicator , bool ):
582
589
self .indicator_name = "_merge" if self .indicator else None
583
590
else :
@@ -597,11 +604,11 @@ def __init__(
597
604
)
598
605
599
606
# warn user when merging between different levels
600
- if left .columns .nlevels != right .columns .nlevels :
607
+ if _left .columns .nlevels != _right .columns .nlevels :
601
608
msg = (
602
609
"merging between different levels can give an unintended "
603
610
"result ({left} levels on the left, {right} on the right)"
604
- ).format (left = left .columns .nlevels , right = right .columns .nlevels )
611
+ ).format (left = _left .columns .nlevels , right = _right .columns .nlevels )
605
612
warnings .warn (msg , UserWarning )
606
613
607
614
self ._validate_specification ()
@@ -658,7 +665,9 @@ def get_result(self):
658
665
659
666
return result
660
667
661
- def _indicator_pre_merge (self , left , right ):
668
+ def _indicator_pre_merge (
669
+ self , left : "DataFrame" , right : "DataFrame"
670
+ ) -> Tuple ["DataFrame" , "DataFrame" ]:
662
671
663
672
columns = left .columns .union (right .columns )
664
673
@@ -878,7 +887,12 @@ def _get_join_info(self):
878
887
return join_index , left_indexer , right_indexer
879
888
880
889
def _create_join_index (
881
- self , index , other_index , indexer , other_indexer , how = "left"
890
+ self ,
891
+ index : Index ,
892
+ other_index : Index ,
893
+ indexer ,
894
+ other_indexer ,
895
+ how : str = "left" ,
882
896
):
883
897
"""
884
898
Create a join index by rearranging one index to match another
@@ -1263,7 +1277,9 @@ def _validate(self, validate: str):
1263
1277
raise ValueError ("Not a valid argument for validate" )
1264
1278
1265
1279
1266
- def _get_join_indexers (left_keys , right_keys , sort = False , how = "inner" , ** kwargs ):
1280
+ def _get_join_indexers (
1281
+ left_keys , right_keys , sort : bool = False , how : str = "inner" , ** kwargs
1282
+ ):
1267
1283
"""
1268
1284
1269
1285
Parameters
@@ -1410,13 +1426,13 @@ def __init__(
1410
1426
on = None ,
1411
1427
left_on = None ,
1412
1428
right_on = None ,
1413
- left_index = False ,
1414
- right_index = False ,
1429
+ left_index : bool = False ,
1430
+ right_index : bool = False ,
1415
1431
axis = 1 ,
1416
1432
suffixes = ("_x" , "_y" ),
1417
- copy = True ,
1433
+ copy : bool = True ,
1418
1434
fill_method = None ,
1419
- how = "outer" ,
1435
+ how : str = "outer" ,
1420
1436
):
1421
1437
1422
1438
self .fill_method = fill_method
@@ -1508,18 +1524,18 @@ def __init__(
1508
1524
on = None ,
1509
1525
left_on = None ,
1510
1526
right_on = None ,
1511
- left_index = False ,
1512
- right_index = False ,
1527
+ left_index : bool = False ,
1528
+ right_index : bool = False ,
1513
1529
by = None ,
1514
1530
left_by = None ,
1515
1531
right_by = None ,
1516
1532
axis = 1 ,
1517
1533
suffixes = ("_x" , "_y" ),
1518
- copy = True ,
1534
+ copy : bool = True ,
1519
1535
fill_method = None ,
1520
- how = "asof" ,
1536
+ how : str = "asof" ,
1521
1537
tolerance = None ,
1522
- allow_exact_matches = True ,
1538
+ allow_exact_matches : bool = True ,
1523
1539
direction = "backward" ,
1524
1540
):
1525
1541
@@ -1757,13 +1773,15 @@ def flip(xs):
1757
1773
return func (left_values , right_values , self .allow_exact_matches , tolerance )
1758
1774
1759
1775
1760
- def _get_multiindex_indexer (join_keys , index , sort ):
1776
+ def _get_multiindex_indexer (join_keys , index : MultiIndex , sort : bool ):
1761
1777
1762
1778
# bind `sort` argument
1763
1779
fkeys = partial (_factorize_keys , sort = sort )
1764
1780
1765
1781
# left & right join labels and num. of levels at each location
1766
- rcodes , lcodes , shape = map (list , zip (* map (fkeys , index .levels , join_keys )))
1782
+ mapped = (fkeys (index .levels [n ], join_keys [n ]) for n in range (len (index .levels )))
1783
+ zipped = zip (* mapped )
1784
+ rcodes , lcodes , shape = [list (x ) for x in zipped ]
1767
1785
if sort :
1768
1786
rcodes = list (map (np .take , rcodes , index .codes ))
1769
1787
else :
@@ -1791,7 +1809,7 @@ def _get_multiindex_indexer(join_keys, index, sort):
1791
1809
return libjoin .left_outer_join (lkey , rkey , count , sort = sort )
1792
1810
1793
1811
1794
- def _get_single_indexer (join_key , index , sort = False ):
1812
+ def _get_single_indexer (join_key , index , sort : bool = False ):
1795
1813
left_key , right_key , count = _factorize_keys (join_key , index , sort = sort )
1796
1814
1797
1815
left_indexer , right_indexer = libjoin .left_outer_join (
@@ -1801,7 +1819,7 @@ def _get_single_indexer(join_key, index, sort=False):
1801
1819
return left_indexer , right_indexer
1802
1820
1803
1821
1804
- def _left_join_on_index (left_ax , right_ax , join_keys , sort = False ):
1822
+ def _left_join_on_index (left_ax : Index , right_ax : Index , join_keys , sort : bool = False ):
1805
1823
if len (join_keys ) > 1 :
1806
1824
if not (
1807
1825
(isinstance (right_ax , MultiIndex ) and len (join_keys ) == right_ax .nlevels )
@@ -1915,7 +1933,7 @@ def _factorize_keys(lk, rk, sort=True):
1915
1933
return llab , rlab , count
1916
1934
1917
1935
1918
- def _sort_labels (uniques , left , right ):
1936
+ def _sort_labels (uniques : np . ndarray , left , right ):
1919
1937
if not isinstance (uniques , np .ndarray ):
1920
1938
# tuplesafe
1921
1939
uniques = Index (uniques ).values
@@ -1930,7 +1948,7 @@ def _sort_labels(uniques, left, right):
1930
1948
return new_left , new_right
1931
1949
1932
1950
1933
- def _get_join_keys (llab , rlab , shape , sort ):
1951
+ def _get_join_keys (llab , rlab , shape , sort : bool ):
1934
1952
1935
1953
# how many levels can be done without overflow
1936
1954
pred = lambda i : not is_int64_overflow_possible (shape [:i ])
@@ -1970,7 +1988,7 @@ def _any(x) -> bool:
1970
1988
return x is not None and com .any_not_none (* x )
1971
1989
1972
1990
1973
- def validate_operand (obj ) :
1991
+ def _validate_operand (obj : FrameOrSeries ) -> "DataFrame" :
1974
1992
if isinstance (obj , ABCDataFrame ):
1975
1993
return obj
1976
1994
elif isinstance (obj , ABCSeries ):
@@ -1985,7 +2003,7 @@ def validate_operand(obj):
1985
2003
)
1986
2004
1987
2005
1988
- def _items_overlap_with_suffix (left , lsuffix , right , rsuffix ):
2006
+ def _items_overlap_with_suffix (left : Index , lsuffix , right : Index , rsuffix ):
1989
2007
"""
1990
2008
If two indices overlap, add suffixes to overlapping entries.
1991
2009
0 commit comments