Skip to content

Commit 2742623

Browse files
authored
TYP: reshape.merge (#53780)
1 parent 0d92b23 commit 2742623

File tree

1 file changed

+51
-41
lines changed

1 file changed

+51
-41
lines changed

pandas/core/reshape/merge.py

+51 -41
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
"""
44
from __future__ import annotations
55

6-
import copy as cp
76
import datetime
87
from functools import partial
98
import string
@@ -13,6 +12,7 @@
1312
Literal,
1413
Sequence,
1514
cast,
15+
final,
1616
)
1717
import uuid
1818
import warnings
@@ -655,8 +655,8 @@ class _MergeOperation:
655655
indicator: str | bool
656656
validate: str | None
657657
join_names: list[Hashable]
658-
right_join_keys: list[AnyArrayLike]
659-
left_join_keys: list[AnyArrayLike]
658+
right_join_keys: list[ArrayLike]
659+
left_join_keys: list[ArrayLike]
660660

661661
def __init__(
662662
self,
@@ -743,6 +743,7 @@ def __init__(
743743
if validate is not None:
744744
self._validate(validate)
745745

746+
@final
746747
def _reindex_and_concat(
747748
self,
748749
join_index: Index,
@@ -821,12 +822,14 @@ def get_result(self, copy: bool | None = True) -> DataFrame:
821822

822823
return result.__finalize__(self, method="merge")
823824

825+
@final
824826
def _maybe_drop_cross_column(
825827
self, result: DataFrame, cross_col: str | None
826828
) -> None:
827829
if cross_col is not None:
828830
del result[cross_col]
829831

832+
@final
830833
@cache_readonly
831834
def _indicator_name(self) -> str | None:
832835
if isinstance(self.indicator, str):
@@ -838,6 +841,7 @@ def _indicator_name(self) -> str | None:
838841
"indicator option can only accept boolean or string arguments"
839842
)
840843

844+
@final
841845
def _indicator_pre_merge(
842846
self, left: DataFrame, right: DataFrame
843847
) -> tuple[DataFrame, DataFrame]:
@@ -865,6 +869,7 @@ def _indicator_pre_merge(
865869

866870
return left, right
867871

872+
@final
868873
def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
869874
result["_left_indicator"] = result["_left_indicator"].fillna(0)
870875
result["_right_indicator"] = result["_right_indicator"].fillna(0)
@@ -880,6 +885,7 @@ def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
880885
result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1)
881886
return result
882887

888+
@final
883889
def _maybe_restore_index_levels(self, result: DataFrame) -> None:
884890
"""
885891
Restore index levels specified as `on` parameters
@@ -923,11 +929,12 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
923929
if names_to_restore:
924930
result.set_index(names_to_restore, inplace=True)
925931

932+
@final
926933
def _maybe_add_join_keys(
927934
self,
928935
result: DataFrame,
929-
left_indexer: np.ndarray | None,
930-
right_indexer: np.ndarray | None,
936+
left_indexer: npt.NDArray[np.intp] | None,
937+
right_indexer: npt.NDArray[np.intp] | None,
931938
) -> None:
932939
left_has_missing = None
933940
right_has_missing = None
@@ -1032,6 +1039,7 @@ def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]
10321039
self.left_join_keys, self.right_join_keys, sort=self.sort, how=self.how
10331040
)
10341041

1042+
@final
10351043
def _get_join_info(
10361044
self,
10371045
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
@@ -1093,6 +1101,7 @@ def _get_join_info(
10931101
join_index = default_index(0).set_names(join_index.name)
10941102
return join_index, left_indexer, right_indexer
10951103

1104+
@final
10961105
def _create_join_index(
10971106
self,
10981107
index: Index,
@@ -1129,7 +1138,7 @@ def _create_join_index(
11291138

11301139
def _get_merge_keys(
11311140
self,
1132-
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
1141+
) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
11331142
"""
11341143
Note: has side effects (copy/delete key columns)
11351144
@@ -1145,8 +1154,8 @@ def _get_merge_keys(
11451154
"""
11461155
# left_keys, right_keys entries can actually be anything listlike
11471156
# with a 'dtype' attr
1148-
left_keys: list[AnyArrayLike] = []
1149-
right_keys: list[AnyArrayLike] = []
1157+
left_keys: list[ArrayLike] = []
1158+
right_keys: list[ArrayLike] = []
11501159
join_names: list[Hashable] = []
11511160
right_drop: list[Hashable] = []
11521161
left_drop: list[Hashable] = []
@@ -1169,11 +1178,13 @@ def _get_merge_keys(
11691178
# ugh, spaghetti re #733
11701179
if _any(self.left_on) and _any(self.right_on):
11711180
for lk, rk in zip(self.left_on, self.right_on):
1181+
lk = extract_array(lk, extract_numpy=True)
1182+
rk = extract_array(rk, extract_numpy=True)
11721183
if is_lkey(lk):
1173-
lk = cast(AnyArrayLike, lk)
1184+
lk = cast(ArrayLike, lk)
11741185
left_keys.append(lk)
11751186
if is_rkey(rk):
1176-
rk = cast(AnyArrayLike, rk)
1187+
rk = cast(ArrayLike, rk)
11771188
right_keys.append(rk)
11781189
join_names.append(None) # what to do?
11791190
else:
@@ -1185,7 +1196,7 @@ def _get_merge_keys(
11851196
join_names.append(rk)
11861197
else:
11871198
# work-around for merge_asof(right_index=True)
1188-
right_keys.append(right.index)
1199+
right_keys.append(right.index._values)
11891200
join_names.append(right.index.name)
11901201
else:
11911202
if not is_rkey(rk):
@@ -1196,7 +1207,7 @@ def _get_merge_keys(
11961207
right_keys.append(right._get_label_or_level_values(rk))
11971208
else:
11981209
# work-around for merge_asof(right_index=True)
1199-
right_keys.append(right.index)
1210+
right_keys.append(right.index._values)
12001211
if lk is not None and lk == rk: # FIXME: what about other NAs?
12011212
# avoid key upcast in corner case (length-0)
12021213
lk = cast(Hashable, lk)
@@ -1205,7 +1216,7 @@ def _get_merge_keys(
12051216
else:
12061217
left_drop.append(lk)
12071218
else:
1208-
rk = cast(AnyArrayLike, rk)
1219+
rk = cast(ArrayLike, rk)
12091220
right_keys.append(rk)
12101221
if lk is not None:
12111222
# Then we're either Hashable or a wrong-length arraylike,
@@ -1215,12 +1226,13 @@ def _get_merge_keys(
12151226
join_names.append(lk)
12161227
else:
12171228
# work-around for merge_asof(left_index=True)
1218-
left_keys.append(left.index)
1229+
left_keys.append(left.index._values)
12191230
join_names.append(left.index.name)
12201231
elif _any(self.left_on):
12211232
for k in self.left_on:
12221233
if is_lkey(k):
1223-
k = cast(AnyArrayLike, k)
1234+
k = extract_array(k, extract_numpy=True)
1235+
k = cast(ArrayLike, k)
12241236
left_keys.append(k)
12251237
join_names.append(None)
12261238
else:
@@ -1240,8 +1252,9 @@ def _get_merge_keys(
12401252
right_keys = [self.right.index._values]
12411253
elif _any(self.right_on):
12421254
for k in self.right_on:
1255+
k = extract_array(k, extract_numpy=True)
12431256
if is_rkey(k):
1244-
k = cast(AnyArrayLike, k)
1257+
k = cast(ArrayLike, k)
12451258
right_keys.append(k)
12461259
join_names.append(None)
12471260
else:
@@ -1268,6 +1281,7 @@ def _get_merge_keys(
12681281

12691282
return left_keys, right_keys, join_names
12701283

1284+
@final
12711285
def _maybe_coerce_merge_keys(self) -> None:
12721286
# we have valid merges but we may have to further
12731287
# coerce these if they are originally incompatible types
@@ -1432,6 +1446,7 @@ def _maybe_coerce_merge_keys(self) -> None:
14321446
self.right = self.right.copy()
14331447
self.right[name] = self.right[name].astype(typ)
14341448

1449+
@final
14351450
def _create_cross_configuration(
14361451
self, left: DataFrame, right: DataFrame
14371452
) -> tuple[DataFrame, DataFrame, JoinHow, str]:
@@ -1610,11 +1625,10 @@ def _validate(self, validate: str) -> None:
16101625

16111626

16121627
def get_join_indexers(
1613-
left_keys: list[AnyArrayLike],
1614-
right_keys: list[AnyArrayLike],
1628+
left_keys: list[ArrayLike],
1629+
right_keys: list[ArrayLike],
16151630
sort: bool = False,
16161631
how: MergeHow | Literal["asof"] = "inner",
1617-
**kwargs,
16181632
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
16191633
"""
16201634
@@ -1667,7 +1681,7 @@ def get_join_indexers(
16671681

16681682
lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort, how=how)
16691683
# preserve left frame order if how == 'left' and sort == False
1670-
kwargs = cp.copy(kwargs)
1684+
kwargs = {}
16711685
if how in ("left", "right"):
16721686
kwargs["sort"] = sort
16731687
join_func = {
@@ -1812,8 +1826,8 @@ def get_result(self, copy: bool | None = True) -> DataFrame:
18121826
self.left._info_axis, self.right._info_axis, self.suffixes
18131827
)
18141828

1815-
left_join_indexer: np.ndarray | None
1816-
right_join_indexer: np.ndarray | None
1829+
left_join_indexer: npt.NDArray[np.intp] | None
1830+
right_join_indexer: npt.NDArray[np.intp] | None
18171831

18181832
if self.fill_method == "ffill":
18191833
if left_indexer is None:
@@ -1984,7 +1998,7 @@ def _validate_left_right_on(self, left_on, right_on):
19841998

19851999
def _get_merge_keys(
19862000
self,
1987-
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
2001+
) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
19882002
# note this function has side effects
19892003
(left_join_keys, right_join_keys, join_names) = super()._get_merge_keys()
19902004

@@ -2016,8 +2030,7 @@ def _get_merge_keys(
20162030
# validate tolerance; datetime.timedelta or Timedelta if we have a DTI
20172031
if self.tolerance is not None:
20182032
if self.left_index:
2019-
# Actually more specifically an Index
2020-
lt = cast(AnyArrayLike, self.left.index)
2033+
lt = self.left.index._values
20212034
else:
20222035
lt = left_join_keys[-1]
20232036

@@ -2026,19 +2039,19 @@ def _get_merge_keys(
20262039
f"with type {repr(lt.dtype)}"
20272040
)
20282041

2029-
if needs_i8_conversion(getattr(lt, "dtype", None)):
2042+
if needs_i8_conversion(lt.dtype):
20302043
if not isinstance(self.tolerance, datetime.timedelta):
20312044
raise MergeError(msg)
20322045
if self.tolerance < Timedelta(0):
20332046
raise MergeError("tolerance must be positive")
20342047

2035-
elif is_integer_dtype(lt):
2048+
elif is_integer_dtype(lt.dtype):
20362049
if not is_integer(self.tolerance):
20372050
raise MergeError(msg)
20382051
if self.tolerance < 0:
20392052
raise MergeError("tolerance must be positive")
20402053

2041-
elif is_float_dtype(lt):
2054+
elif is_float_dtype(lt.dtype):
20422055
if not is_number(self.tolerance):
20432056
raise MergeError(msg)
20442057
# error: Unsupported operand types for > ("int" and "Number")
@@ -2061,10 +2074,10 @@ def _get_merge_keys(
20612074
def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
20622075
"""return the join indexers"""
20632076

2064-
def flip(xs: list[AnyArrayLike]) -> np.ndarray:
2077+
def flip(xs: list[ArrayLike]) -> np.ndarray:
20652078
"""unlike np.transpose, this returns an array of tuples"""
20662079

2067-
def injection(obj: AnyArrayLike):
2080+
def injection(obj: ArrayLike):
20682081
if not isinstance(obj.dtype, ExtensionDtype):
20692082
# ndarray
20702083
return obj
@@ -2212,11 +2225,11 @@ def injection(obj: AnyArrayLike):
22122225

22132226

22142227
def _get_multiindex_indexer(
2215-
join_keys: list[AnyArrayLike], index: MultiIndex, sort: bool
2228+
join_keys: list[ArrayLike], index: MultiIndex, sort: bool
22162229
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
22172230
# left & right join labels and num. of levels at each location
22182231
mapped = (
2219-
_factorize_keys(index.levels[n], join_keys[n], sort=sort)
2232+
_factorize_keys(index.levels[n]._values, join_keys[n], sort=sort)
22202233
for n in range(index.nlevels)
22212234
)
22222235
zipped = zip(*mapped)
@@ -2249,7 +2262,7 @@ def _get_multiindex_indexer(
22492262

22502263

22512264
def _get_single_indexer(
2252-
join_key: AnyArrayLike, index: Index, sort: bool = False
2265+
join_key: ArrayLike, index: Index, sort: bool = False
22532266
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
22542267
left_key, right_key, count = _factorize_keys(join_key, index._values, sort=sort)
22552268

@@ -2294,7 +2307,7 @@ def _get_no_sort_one_missing_indexer(
22942307

22952308

22962309
def _left_join_on_index(
2297-
left_ax: Index, right_ax: Index, join_keys: list[AnyArrayLike], sort: bool = False
2310+
left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False
22982311
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]:
22992312
if isinstance(right_ax, MultiIndex):
23002313
left_indexer, right_indexer = _get_multiindex_indexer(
@@ -2315,8 +2328,8 @@ def _left_join_on_index(
23152328

23162329

23172330
def _factorize_keys(
2318-
lk: AnyArrayLike,
2319-
rk: AnyArrayLike,
2331+
lk: ArrayLike,
2332+
rk: ArrayLike,
23202333
sort: bool = True,
23212334
how: MergeHow | Literal["asof"] = "inner",
23222335
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
@@ -2327,9 +2340,9 @@ def _factorize_keys(
23272340
23282341
Parameters
23292342
----------
2330-
lk : ndarray, ExtensionArray, Index, or Series
2343+
lk : ndarray, ExtensionArray
23312344
Left key.
2332-
rk : ndarray, ExtensionArray, Index, or Series
2345+
rk : ndarray, ExtensionArray
23332346
Right key.
23342347
sort : bool, defaults to True
23352348
If True, the encoding is done such that the unique elements in the
@@ -2370,9 +2383,6 @@ def _factorize_keys(
23702383
>>> pd.core.reshape.merge._factorize_keys(lk, rk, sort=False)
23712384
(array([0, 1, 2]), array([0, 1]), 3)
23722385
"""
2373-
# Some pre-processing for non-ndarray lk / rk
2374-
lk = extract_array(lk, extract_numpy=True, extract_range=True)
2375-
rk = extract_array(rk, extract_numpy=True, extract_range=True)
23762386
# TODO: if either is a RangeIndex, we can likely factorize more efficiently?
23772387

23782388
if (

0 commit comments

Comments
 (0)