From 7c07ad31fe8fe68fa91e73321d02ccee012fdaf6 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 22 Jun 2023 16:39:24 -0700 Subject: [PATCH 1/2] REF: simplify merge code --- pandas/core/reshape/merge.py | 97 ++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 4d1f8bd6301d0..9c8862389fc35 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -717,13 +717,23 @@ def __init__( self.left_on = self.right_on = [cross_col] self._cross = cross_col - # note this function has side effects ( self.left_join_keys, self.right_join_keys, self.join_names, + left_drop, + right_drop, ) = self._get_merge_keys() + if left_drop: + self.left = self.left._drop_labels_or_levels(left_drop) + + if right_drop: + self.right = self.right._drop_labels_or_levels(right_drop) + + self._maybe_require_matching_dtypes(self.left_join_keys, self.right_join_keys) + self._validate_tolerance(self.left_join_keys) + # validate the merge keys dtypes. We may need to coerce # to avoid incompatible dtypes self._maybe_coerce_merge_keys() @@ -732,7 +742,15 @@ def __init__( # check if columns specified as unique # are in fact unique. if validate is not None: - self._validate(validate) + self._validate_validate_kwd(validate) + + def _maybe_require_matching_dtypes(self, left_join_keys, right_join_keys) -> None: + # Overridden by AsOfMerge + pass + + def _validate_tolerance(self, left_join_keys): + # Overridden by AsOfMerge + pass @final def _reindex_and_concat( @@ -1127,24 +1145,21 @@ def _create_join_index( index = index.append(Index([fill_value])) return index.take(indexer) + @final def _get_merge_keys( self, - ) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]: + ) -> tuple[ + list[ArrayLike], + list[ArrayLike], + list[Hashable], + list[Hashable], + list[Hashable], + ]: """ - Note: has side effects (copy/delete key columns) - - Parameters - ---------- - left - right - on - Returns ------- - left_keys, right_keys, join_names + left_keys, right_keys, join_names, left_drop, right_drop """ - # left_keys, right_keys entries can actually be anything listlike - # with a 'dtype' attr left_keys: list[ArrayLike] = [] right_keys: list[ArrayLike] = [] join_names: list[Hashable] = [] @@ -1264,13 +1279,7 @@ def _get_merge_keys( else: left_keys = [self.left.index._values] - if left_drop: - self.left = self.left._drop_labels_or_levels(left_drop) - - if right_drop: - self.right = self.right._drop_labels_or_levels(right_drop) - - return left_keys, right_keys, join_names + return left_keys, right_keys, join_names, left_drop, right_drop @final def _maybe_coerce_merge_keys(self) -> None: @@ -1556,7 +1565,8 @@ def _validate_left_right_on(self, left_on, right_on): return left_on, right_on - def _validate(self, validate: str) -> None: + @final + def _validate_validate_kwd(self, validate: str) -> None: # Check uniqueness of each if self.left_index: left_unique = self.orig_left.index.is_unique @@ -1811,19 +1821,14 @@ def __init__( def get_result(self, copy: bool | None = True) -> DataFrame: join_index, left_indexer, right_indexer = self._get_join_info() - llabels, rlabels = _items_overlap_with_suffix( - self.left._info_axis, self.right._info_axis, self.suffixes - ) - left_join_indexer: npt.NDArray[np.intp] | None right_join_indexer: npt.NDArray[np.intp] | None if self.fill_method == "ffill": if left_indexer is None: raise TypeError("left_indexer cannot be None") - left_indexer, right_indexer = cast(np.ndarray, left_indexer), cast( - np.ndarray, right_indexer - ) + left_indexer = cast("npt.NDArray[np.intp]", left_indexer) + right_indexer = cast("npt.NDArray[np.intp]", right_indexer) left_join_indexer = libjoin.ffill_indexer(left_indexer) right_join_indexer = libjoin.ffill_indexer(right_indexer) else: @@ -1888,6 +1893,18 @@ def __init__( self.allow_exact_matches = allow_exact_matches self.direction = direction + # check 'direction' is valid + if self.direction not in ["backward", "forward", "nearest"]: + raise MergeError(f"direction invalid: {self.direction}") + + # validate allow_exact_matches + if not is_bool(self.allow_exact_matches): + msg = ( + "allow_exact_matches must be boolean, " + f"passed {self.allow_exact_matches}" + ) + raise MergeError(msg) + _OrderedMerge.__init__( self, left, @@ -1975,17 +1992,10 @@ def _validate_left_right_on(self, left_on, right_on): left_on = self.left_by + list(left_on) right_on = self.right_by + list(right_on) - # check 'direction' is valid - if self.direction not in ["backward", "forward", "nearest"]: - raise MergeError(f"direction invalid: {self.direction}") - return left_on, right_on - def _get_merge_keys( - self, - ) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]: - # note this function has side effects - (left_join_keys, right_join_keys, join_names) = super()._get_merge_keys() + def _maybe_require_matching_dtypes(self, left_join_keys, right_join_keys) -> None: + # TODO: why do we do this for AsOfMerge but not the others? # validate index types are the same for i, (lk, rk) in enumerate(zip(left_join_keys, right_join_keys)): @@ -2012,6 +2022,7 @@ def _get_merge_keys( ) raise MergeError(msg) + def _validate_tolerance(self, left_join_keys) -> None: # validate tolerance; datetime.timedelta or Timedelta if we have a DTI if self.tolerance is not None: if self.left_index: @@ -2046,16 +2057,6 @@ def _get_merge_keys( else: raise MergeError("key must be integer, timestamp or float") - # validate allow_exact_matches - if not is_bool(self.allow_exact_matches): - msg = ( - "allow_exact_matches must be boolean, " - f"passed {self.allow_exact_matches}" - ) - raise MergeError(msg) - - return left_join_keys, right_join_keys, join_names - def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: """return the join indexers""" From e6f0313798db925c25d8b8f9b3e29f54766f3ae3 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 22 Jun 2023 19:37:16 -0700 Subject: [PATCH 2/2] typing --- pandas/core/reshape/merge.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 9c8862389fc35..d406145e62ad7 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -744,11 +744,13 @@ def __init__( if validate is not None: self._validate_validate_kwd(validate) - def _maybe_require_matching_dtypes(self, left_join_keys, right_join_keys) -> None: + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: # Overridden by AsOfMerge pass - def _validate_tolerance(self, left_join_keys): + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: # Overridden by AsOfMerge pass @@ -1994,7 +1996,9 @@ def _validate_left_right_on(self, left_on, right_on): return left_on, right_on - def _maybe_require_matching_dtypes(self, left_join_keys, right_join_keys) -> None: + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: # TODO: why do we do this for AsOfMerge but not the others? # validate index types are the same @@ -2022,7 +2026,7 @@ def _maybe_require_matching_dtypes(self, left_join_keys, right_join_keys) -> Non ) raise MergeError(msg) - def _validate_tolerance(self, left_join_keys) -> None: + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: # validate tolerance; datetime.timedelta or Timedelta if we have a DTI if self.tolerance is not None: if self.left_index: