Skip to content

Commit eac2c00

Browse files
authored
REF: simplify merge code (#53808)
* REF: simplify merge code * typing
1 parent 8bab235 commit eac2c00

File tree

1 file changed

+53
-48
lines changed

1 file changed

+53
-48
lines changed

pandas/core/reshape/merge.py

+53-48
Original file line numberDiff line numberDiff line change
@@ -717,13 +717,23 @@ def __init__(
717717
self.left_on = self.right_on = [cross_col]
718718
self._cross = cross_col
719719

720-
# note this function has side effects
721720
(
722721
self.left_join_keys,
723722
self.right_join_keys,
724723
self.join_names,
724+
left_drop,
725+
right_drop,
725726
) = self._get_merge_keys()
726727

728+
if left_drop:
729+
self.left = self.left._drop_labels_or_levels(left_drop)
730+
731+
if right_drop:
732+
self.right = self.right._drop_labels_or_levels(right_drop)
733+
734+
self._maybe_require_matching_dtypes(self.left_join_keys, self.right_join_keys)
735+
self._validate_tolerance(self.left_join_keys)
736+
727737
# validate the merge keys dtypes. We may need to coerce
728738
# to avoid incompatible dtypes
729739
self._maybe_coerce_merge_keys()
@@ -732,7 +742,17 @@ def __init__(
732742
# check if columns specified as unique
733743
# are in fact unique.
734744
if validate is not None:
735-
self._validate(validate)
745+
self._validate_validate_kwd(validate)
746+
747+
def _maybe_require_matching_dtypes(
748+
self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike]
749+
) -> None:
750+
# Overridden by AsOfMerge
751+
pass
752+
753+
def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
754+
# Overridden by AsOfMerge
755+
pass
736756

737757
@final
738758
def _reindex_and_concat(
@@ -1127,24 +1147,21 @@ def _create_join_index(
11271147
index = index.append(Index([fill_value]))
11281148
return index.take(indexer)
11291149

1150+
@final
11301151
def _get_merge_keys(
11311152
self,
1132-
) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
1153+
) -> tuple[
1154+
list[ArrayLike],
1155+
list[ArrayLike],
1156+
list[Hashable],
1157+
list[Hashable],
1158+
list[Hashable],
1159+
]:
11331160
"""
1134-
Note: has side effects (copy/delete key columns)
1135-
1136-
Parameters
1137-
----------
1138-
left
1139-
right
1140-
on
1141-
11421161
Returns
11431162
-------
1144-
left_keys, right_keys, join_names
1163+
left_keys, right_keys, join_names, left_drop, right_drop
11451164
"""
1146-
# left_keys, right_keys entries can actually be anything listlike
1147-
# with a 'dtype' attr
11481165
left_keys: list[ArrayLike] = []
11491166
right_keys: list[ArrayLike] = []
11501167
join_names: list[Hashable] = []
@@ -1264,13 +1281,7 @@ def _get_merge_keys(
12641281
else:
12651282
left_keys = [self.left.index._values]
12661283

1267-
if left_drop:
1268-
self.left = self.left._drop_labels_or_levels(left_drop)
1269-
1270-
if right_drop:
1271-
self.right = self.right._drop_labels_or_levels(right_drop)
1272-
1273-
return left_keys, right_keys, join_names
1284+
return left_keys, right_keys, join_names, left_drop, right_drop
12741285

12751286
@final
12761287
def _maybe_coerce_merge_keys(self) -> None:
@@ -1556,7 +1567,8 @@ def _validate_left_right_on(self, left_on, right_on):
15561567

15571568
return left_on, right_on
15581569

1559-
def _validate(self, validate: str) -> None:
1570+
@final
1571+
def _validate_validate_kwd(self, validate: str) -> None:
15601572
# Check uniqueness of each
15611573
if self.left_index:
15621574
left_unique = self.orig_left.index.is_unique
@@ -1811,19 +1823,14 @@ def __init__(
18111823
def get_result(self, copy: bool | None = True) -> DataFrame:
18121824
join_index, left_indexer, right_indexer = self._get_join_info()
18131825

1814-
llabels, rlabels = _items_overlap_with_suffix(
1815-
self.left._info_axis, self.right._info_axis, self.suffixes
1816-
)
1817-
18181826
left_join_indexer: npt.NDArray[np.intp] | None
18191827
right_join_indexer: npt.NDArray[np.intp] | None
18201828

18211829
if self.fill_method == "ffill":
18221830
if left_indexer is None:
18231831
raise TypeError("left_indexer cannot be None")
1824-
left_indexer, right_indexer = cast(np.ndarray, left_indexer), cast(
1825-
np.ndarray, right_indexer
1826-
)
1832+
left_indexer = cast("npt.NDArray[np.intp]", left_indexer)
1833+
right_indexer = cast("npt.NDArray[np.intp]", right_indexer)
18271834
left_join_indexer = libjoin.ffill_indexer(left_indexer)
18281835
right_join_indexer = libjoin.ffill_indexer(right_indexer)
18291836
else:
@@ -1888,6 +1895,18 @@ def __init__(
18881895
self.allow_exact_matches = allow_exact_matches
18891896
self.direction = direction
18901897

1898+
# check 'direction' is valid
1899+
if self.direction not in ["backward", "forward", "nearest"]:
1900+
raise MergeError(f"direction invalid: {self.direction}")
1901+
1902+
# validate allow_exact_matches
1903+
if not is_bool(self.allow_exact_matches):
1904+
msg = (
1905+
"allow_exact_matches must be boolean, "
1906+
f"passed {self.allow_exact_matches}"
1907+
)
1908+
raise MergeError(msg)
1909+
18911910
_OrderedMerge.__init__(
18921911
self,
18931912
left,
@@ -1975,17 +1994,12 @@ def _validate_left_right_on(self, left_on, right_on):
19751994
left_on = self.left_by + list(left_on)
19761995
right_on = self.right_by + list(right_on)
19771996

1978-
# check 'direction' is valid
1979-
if self.direction not in ["backward", "forward", "nearest"]:
1980-
raise MergeError(f"direction invalid: {self.direction}")
1981-
19821997
return left_on, right_on
19831998

1984-
def _get_merge_keys(
1985-
self,
1986-
) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
1987-
# note this function has side effects
1988-
(left_join_keys, right_join_keys, join_names) = super()._get_merge_keys()
1999+
def _maybe_require_matching_dtypes(
2000+
self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike]
2001+
) -> None:
2002+
# TODO: why do we do this for AsOfMerge but not the others?
19892003

19902004
# validate index types are the same
19912005
for i, (lk, rk) in enumerate(zip(left_join_keys, right_join_keys)):
@@ -2012,6 +2026,7 @@ def _get_merge_keys(
20122026
)
20132027
raise MergeError(msg)
20142028

2029+
def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
20152030
# validate tolerance; datetime.timedelta or Timedelta if we have a DTI
20162031
if self.tolerance is not None:
20172032
if self.left_index:
@@ -2046,16 +2061,6 @@ def _get_merge_keys(
20462061
else:
20472062
raise MergeError("key must be integer, timestamp or float")
20482063

2049-
# validate allow_exact_matches
2050-
if not is_bool(self.allow_exact_matches):
2051-
msg = (
2052-
"allow_exact_matches must be boolean, "
2053-
f"passed {self.allow_exact_matches}"
2054-
)
2055-
raise MergeError(msg)
2056-
2057-
return left_join_keys, right_join_keys, join_names
2058-
20592064
def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
20602065
"""return the join indexers"""
20612066

0 commit comments

Comments
 (0)