diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 4d1f8bd6301d0..d406145e62ad7 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -717,13 +717,23 @@ def __init__( self.left_on = self.right_on = [cross_col] self._cross = cross_col - # note this function has side effects ( self.left_join_keys, self.right_join_keys, self.join_names, + left_drop, + right_drop, ) = self._get_merge_keys() + if left_drop: + self.left = self.left._drop_labels_or_levels(left_drop) + + if right_drop: + self.right = self.right._drop_labels_or_levels(right_drop) + + self._maybe_require_matching_dtypes(self.left_join_keys, self.right_join_keys) + self._validate_tolerance(self.left_join_keys) + # validate the merge keys dtypes. We may need to coerce # to avoid incompatible dtypes self._maybe_coerce_merge_keys() @@ -732,7 +742,17 @@ def __init__( # check if columns specified as unique # are in fact unique. if validate is not None: - self._validate(validate) + self._validate_validate_kwd(validate) + + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: + # Overridden by AsOfMerge + pass + + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: + # Overridden by AsOfMerge + pass @final def _reindex_and_concat( @@ -1127,24 +1147,21 @@ def _create_join_index( index = index.append(Index([fill_value])) return index.take(indexer) + @final def _get_merge_keys( self, - ) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]: + ) -> tuple[ + list[ArrayLike], + list[ArrayLike], + list[Hashable], + list[Hashable], + list[Hashable], + ]: """ - Note: has side effects (copy/delete key columns) - - Parameters - ---------- - left - right - on - Returns ------- - left_keys, right_keys, join_names + left_keys, right_keys, join_names, left_drop, right_drop """ - # left_keys, right_keys entries can actually be anything listlike - # with a 'dtype' attr left_keys: list[ArrayLike] = [] right_keys: list[ArrayLike] = [] join_names: list[Hashable] = [] @@ -1264,13 +1281,7 @@ def _get_merge_keys( else: left_keys = [self.left.index._values] - if left_drop: - self.left = self.left._drop_labels_or_levels(left_drop) - - if right_drop: - self.right = self.right._drop_labels_or_levels(right_drop) - - return left_keys, right_keys, join_names + return left_keys, right_keys, join_names, left_drop, right_drop @final def _maybe_coerce_merge_keys(self) -> None: @@ -1556,7 +1567,8 @@ def _validate_left_right_on(self, left_on, right_on): return left_on, right_on - def _validate(self, validate: str) -> None: + @final + def _validate_validate_kwd(self, validate: str) -> None: # Check uniqueness of each if self.left_index: left_unique = self.orig_left.index.is_unique @@ -1811,19 +1823,14 @@ def __init__( def get_result(self, copy: bool | None = True) -> DataFrame: join_index, left_indexer, right_indexer = self._get_join_info() - llabels, rlabels = _items_overlap_with_suffix( - self.left._info_axis, self.right._info_axis, self.suffixes - ) - left_join_indexer: npt.NDArray[np.intp] | None right_join_indexer: npt.NDArray[np.intp] | None if self.fill_method == "ffill": if left_indexer is None: raise TypeError("left_indexer cannot be None") - left_indexer, right_indexer = cast(np.ndarray, left_indexer), cast( - np.ndarray, right_indexer - ) + left_indexer = cast("npt.NDArray[np.intp]", left_indexer) + right_indexer = cast("npt.NDArray[np.intp]", right_indexer) left_join_indexer = libjoin.ffill_indexer(left_indexer) right_join_indexer = libjoin.ffill_indexer(right_indexer) else: @@ -1888,6 +1895,18 @@ def __init__( self.allow_exact_matches = allow_exact_matches self.direction = direction + # check 'direction' is valid + if self.direction not in ["backward", "forward", "nearest"]: + raise MergeError(f"direction invalid: {self.direction}") + + # validate allow_exact_matches + if not is_bool(self.allow_exact_matches): + msg = ( + "allow_exact_matches must be boolean, " + f"passed {self.allow_exact_matches}" + ) + raise MergeError(msg) + _OrderedMerge.__init__( self, left, @@ -1975,17 +1994,12 @@ def _validate_left_right_on(self, left_on, right_on): left_on = self.left_by + list(left_on) right_on = self.right_by + list(right_on) - # check 'direction' is valid - if self.direction not in ["backward", "forward", "nearest"]: - raise MergeError(f"direction invalid: {self.direction}") - return left_on, right_on - def _get_merge_keys( - self, - ) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]: - # note this function has side effects - (left_join_keys, right_join_keys, join_names) = super()._get_merge_keys() + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: + # TODO: why do we do this for AsOfMerge but not the others? # validate index types are the same for i, (lk, rk) in enumerate(zip(left_join_keys, right_join_keys)): @@ -2012,6 +2026,7 @@ def _get_merge_keys( ) raise MergeError(msg) + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: # validate tolerance; datetime.timedelta or Timedelta if we have a DTI if self.tolerance is not None: if self.left_index: @@ -2046,16 +2061,6 @@ def _get_merge_keys( else: raise MergeError("key must be integer, timestamp or float") - # validate allow_exact_matches - if not is_bool(self.allow_exact_matches): - msg = ( - "allow_exact_matches must be boolean, " - f"passed {self.allow_exact_matches}" - ) - raise MergeError(msg) - - return left_join_keys, right_join_keys, join_names - def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: """return the join indexers"""