diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index c01bf3931b27a..5ceac80c340ba 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -10,6 +10,8 @@ import string from typing import ( TYPE_CHECKING, + Hashable, + List, Optional, Tuple, cast, @@ -124,14 +126,13 @@ def merge( merge.__doc__ = _merge_doc % "\nleft : DataFrame" -def _groupby_and_merge(by, on, left: DataFrame, right: DataFrame, merge_pieces): +def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces): """ groupby & merge; we are always performing a left-by type operation Parameters ---------- by: field to group - on: duplicates field left: DataFrame right: DataFrame merge_pieces: function for merging @@ -307,9 +308,7 @@ def _merger(x, y): check = set(left_by).difference(left.columns) if len(check) != 0: raise KeyError(f"{check} not found in left columns") - result, _ = _groupby_and_merge( - left_by, on, left, right, lambda x, y: _merger(x, y) - ) + result, _ = _groupby_and_merge(left_by, left, right, lambda x, y: _merger(x, y)) elif right_by is not None: if isinstance(right_by, str): right_by = [right_by] @@ -317,7 +316,7 @@ def _merger(x, y): if len(check) != 0: raise KeyError(f"{check} not found in right columns") result, _ = _groupby_and_merge( - right_by, on, right, left, lambda x, y: _merger(y, x) + right_by, right, left, lambda x, y: _merger(y, x) ) else: result = _merger(left, right) @@ -708,7 +707,7 @@ def __init__( if validate is not None: self._validate(validate) - def get_result(self): + def get_result(self) -> DataFrame: if self.indicator: self.left, self.right = self._indicator_pre_merge(self.left, self.right) @@ -774,7 +773,7 @@ def _indicator_pre_merge( return left, right - def _indicator_post_merge(self, result): + def _indicator_post_merge(self, result: DataFrame) -> DataFrame: result["_left_indicator"] = result["_left_indicator"].fillna(0) result["_right_indicator"] = result["_right_indicator"].fillna(0) @@ -790,7 +789,7 @@ def _indicator_post_merge(self, result): result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1) return result - def _maybe_restore_index_levels(self, result): + def _maybe_restore_index_levels(self, result: DataFrame) -> None: """ Restore index levels specified as `on` parameters @@ -949,7 +948,6 @@ def _get_join_info(self): self.left.index, self.right.index, left_indexer, - right_indexer, how="right", ) else: @@ -961,7 +959,6 @@ def _get_join_info(self): self.right.index, self.left.index, right_indexer, - left_indexer, how="left", ) else: @@ -979,9 +976,8 @@ def _create_join_index( index: Index, other_index: Index, indexer, - other_indexer, how: str = "left", - ): + ) -> Index: """ Create a join index by rearranging one index to match another @@ -1126,7 +1122,7 @@ def _get_merge_keys(self): return left_keys, right_keys, join_names - def _maybe_coerce_merge_keys(self): + def _maybe_coerce_merge_keys(self) -> None: # we have valid merges but we may have to further # coerce these if they are originally incompatible types # @@ -1285,7 +1281,7 @@ def _create_cross_configuration( cross_col, ) - def _validate_specification(self): + def _validate_specification(self) -> None: if self.how == "cross": if ( self.left_index @@ -1372,7 +1368,7 @@ def _validate_specification(self): if self.how != "cross" and len(self.right_on) != len(self.left_on): raise ValueError("len(right_on) must equal len(left_on)") - def _validate(self, validate: str): + def _validate(self, validate: str) -> None: # Check uniqueness of each if self.left_index: @@ -1479,10 +1475,10 @@ def restore_dropped_levels_multijoin( left: MultiIndex, right: MultiIndex, dropped_level_names, - join_index, - lindexer, - rindexer, -): + join_index: Index, + lindexer: np.ndarray, + rindexer: np.ndarray, +) -> Tuple[List[Index], np.ndarray, List[Hashable]]: """ *this is an internal non-public method* @@ -1500,7 +1496,7 @@ def restore_dropped_levels_multijoin( right index dropped_level_names : str array list of non-common level names - join_index : MultiIndex + join_index : Index the index of the join between the common levels of left and right lindexer : intp array @@ -1514,8 +1510,8 @@ def restore_dropped_levels_multijoin( levels of combined multiindexes labels : intp array labels of combined multiindexes - names : str array - names of combined multiindexes + names : List[Hashable] + names of combined multiindex levels """ @@ -1604,7 +1600,7 @@ def __init__( sort=True, # factorize sorts ) - def get_result(self): + def get_result(self) -> DataFrame: join_index, left_indexer, right_indexer = self._get_join_info() llabels, rlabels = _items_overlap_with_suffix( @@ -1653,7 +1649,7 @@ def _asof_by_function(direction: str): } -def _get_cython_type_upcast(dtype): +def _get_cython_type_upcast(dtype) -> str: """ Upcast a dtype to 'int64_t', 'double', or 'object' """ if is_integer_dtype(dtype): return "int64_t"