diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index bc23d50c634d5..2674b7ee95088 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -6,6 +6,7 @@ import datetime from functools import partial import string +from typing import TYPE_CHECKING, Optional, Tuple, Union import warnings import numpy as np @@ -39,6 +40,7 @@ from pandas.core.dtypes.missing import isna, na_value_for_dtype from pandas import Categorical, Index, MultiIndex +from pandas._typing import FrameOrSeries import pandas.core.algorithms as algos from pandas.core.arrays.categorical import _recode_for_categories import pandas.core.common as com @@ -46,22 +48,25 @@ from pandas.core.internals import _transform_index, concatenate_block_managers from pandas.core.sorting import is_int64_overflow_possible +if TYPE_CHECKING: + from pandas import DataFrame, Series # noqa:F401 + @Substitution("\nleft : DataFrame") @Appender(_merge_doc, indents=0) def merge( left, right, - how="inner", + how: str = "inner", on=None, left_on=None, right_on=None, - left_index=False, - right_index=False, - sort=False, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, suffixes=("_x", "_y"), - copy=True, - indicator=False, + copy: bool = True, + indicator: bool = False, validate=None, ): op = _MergeOperation( @@ -86,7 +91,9 @@ def merge( merge.__doc__ = _merge_doc % "\nleft : DataFrame" -def _groupby_and_merge(by, on, left, right, _merge_pieces, check_duplicates=True): +def _groupby_and_merge( + by, on, left, right, _merge_pieces, check_duplicates: bool = True +): """ groupby & merge; we are always performing a left-by type operation @@ -172,7 +179,7 @@ def merge_ordered( right_by=None, fill_method=None, suffixes=("_x", "_y"), - how="outer", + how: str = "outer", ): """ Perform merge with optional filling/interpolation. @@ -298,14 +305,14 @@ def merge_asof( on=None, left_on=None, right_on=None, - left_index=False, - right_index=False, + left_index: bool = False, + right_index: bool = False, by=None, left_by=None, right_by=None, suffixes=("_x", "_y"), tolerance=None, - allow_exact_matches=True, + allow_exact_matches: bool = True, direction="backward", ): """ @@ -533,33 +540,33 @@ def merge_asof( # TODO: only copy DataFrames when modification necessary class _MergeOperation: """ - Perform a database (SQL) merge operation between two DataFrame objects - using either columns as keys or their row indexes + Perform a database (SQL) merge operation between two DataFrame or Series + objects using either columns as keys or their row indexes """ _merge_type = "merge" def __init__( self, - left, - right, - how="inner", + left: Union["Series", "DataFrame"], + right: Union["Series", "DataFrame"], + how: str = "inner", on=None, left_on=None, right_on=None, axis=1, - left_index=False, - right_index=False, - sort=True, + left_index: bool = False, + right_index: bool = False, + sort: bool = True, suffixes=("_x", "_y"), - copy=True, - indicator=False, + copy: bool = True, + indicator: bool = False, validate=None, ): - left = validate_operand(left) - right = validate_operand(right) - self.left = self.orig_left = left - self.right = self.orig_right = right + _left = _validate_operand(left) + _right = _validate_operand(right) + self.left = self.orig_left = _validate_operand(_left) # type: "DataFrame" + self.right = self.orig_right = _validate_operand(_right) # type: "DataFrame" self.how = how self.axis = axis @@ -577,7 +584,7 @@ def __init__( self.indicator = indicator if isinstance(self.indicator, str): - self.indicator_name = self.indicator + self.indicator_name = self.indicator # type: Optional[str] elif isinstance(self.indicator, bool): self.indicator_name = "_merge" if self.indicator else None else: @@ -597,11 +604,11 @@ def __init__( ) # warn user when merging between different levels - if left.columns.nlevels != right.columns.nlevels: + if _left.columns.nlevels != _right.columns.nlevels: msg = ( "merging between different levels can give an unintended " "result ({left} levels on the left, {right} on the right)" - ).format(left=left.columns.nlevels, right=right.columns.nlevels) + ).format(left=_left.columns.nlevels, right=_right.columns.nlevels) warnings.warn(msg, UserWarning) self._validate_specification() @@ -658,7 +665,9 @@ def get_result(self): return result - def _indicator_pre_merge(self, left, right): + def _indicator_pre_merge( + self, left: "DataFrame", right: "DataFrame" + ) -> Tuple["DataFrame", "DataFrame"]: columns = left.columns.union(right.columns) @@ -878,7 +887,12 @@ def _get_join_info(self): return join_index, left_indexer, right_indexer def _create_join_index( - self, index, other_index, indexer, other_indexer, how="left" + self, + index: Index, + other_index: Index, + indexer, + other_indexer, + how: str = "left", ): """ Create a join index by rearranging one index to match another @@ -1263,7 +1277,9 @@ def _validate(self, validate: str): raise ValueError("Not a valid argument for validate") -def _get_join_indexers(left_keys, right_keys, sort=False, how="inner", **kwargs): +def _get_join_indexers( + left_keys, right_keys, sort: bool = False, how: str = "inner", **kwargs +): """ Parameters @@ -1410,13 +1426,13 @@ def __init__( on=None, left_on=None, right_on=None, - left_index=False, - right_index=False, + left_index: bool = False, + right_index: bool = False, axis=1, suffixes=("_x", "_y"), - copy=True, + copy: bool = True, fill_method=None, - how="outer", + how: str = "outer", ): self.fill_method = fill_method @@ -1508,18 +1524,18 @@ def __init__( on=None, left_on=None, right_on=None, - left_index=False, - right_index=False, + left_index: bool = False, + right_index: bool = False, by=None, left_by=None, right_by=None, axis=1, suffixes=("_x", "_y"), - copy=True, + copy: bool = True, fill_method=None, - how="asof", + how: str = "asof", tolerance=None, - allow_exact_matches=True, + allow_exact_matches: bool = True, direction="backward", ): @@ -1757,13 +1773,15 @@ def flip(xs): return func(left_values, right_values, self.allow_exact_matches, tolerance) -def _get_multiindex_indexer(join_keys, index, sort): +def _get_multiindex_indexer(join_keys, index: MultiIndex, sort: bool): # bind `sort` argument fkeys = partial(_factorize_keys, sort=sort) # left & right join labels and num. of levels at each location - rcodes, lcodes, shape = map(list, zip(*map(fkeys, index.levels, join_keys))) + mapped = (fkeys(index.levels[n], join_keys[n]) for n in range(len(index.levels))) + zipped = zip(*mapped) + rcodes, lcodes, shape = [list(x) for x in zipped] if sort: rcodes = list(map(np.take, rcodes, index.codes)) else: @@ -1791,7 +1809,7 @@ def _get_multiindex_indexer(join_keys, index, sort): return libjoin.left_outer_join(lkey, rkey, count, sort=sort) -def _get_single_indexer(join_key, index, sort=False): +def _get_single_indexer(join_key, index, sort: bool = False): left_key, right_key, count = _factorize_keys(join_key, index, sort=sort) left_indexer, right_indexer = libjoin.left_outer_join( @@ -1801,7 +1819,7 @@ def _get_single_indexer(join_key, index, sort=False): return left_indexer, right_indexer -def _left_join_on_index(left_ax, right_ax, join_keys, sort=False): +def _left_join_on_index(left_ax: Index, right_ax: Index, join_keys, sort: bool = False): if len(join_keys) > 1: if not ( (isinstance(right_ax, MultiIndex) and len(join_keys) == right_ax.nlevels) @@ -1915,7 +1933,7 @@ def _factorize_keys(lk, rk, sort=True): return llab, rlab, count -def _sort_labels(uniques, left, right): +def _sort_labels(uniques: np.ndarray, left, right): if not isinstance(uniques, np.ndarray): # tuplesafe uniques = Index(uniques).values @@ -1930,7 +1948,7 @@ def _sort_labels(uniques, left, right): return new_left, new_right -def _get_join_keys(llab, rlab, shape, sort): +def _get_join_keys(llab, rlab, shape, sort: bool): # how many levels can be done without overflow pred = lambda i: not is_int64_overflow_possible(shape[:i]) @@ -1970,7 +1988,7 @@ def _any(x) -> bool: return x is not None and com.any_not_none(*x) -def validate_operand(obj): +def _validate_operand(obj: FrameOrSeries) -> "DataFrame": if isinstance(obj, ABCDataFrame): return obj elif isinstance(obj, ABCSeries): @@ -1985,7 +2003,7 @@ def validate_operand(obj): ) -def _items_overlap_with_suffix(left, lsuffix, right, rsuffix): +def _items_overlap_with_suffix(left: Index, lsuffix, right: Index, rsuffix): """ If two indices overlap, add suffixes to overlapping entries.