Skip to content

TYP: MergeHow + JoinHow #49664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 5, 2022
6 changes: 6 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ def closed(self) -> bool:
# dropna
AnyAll = Literal["any", "all"]

# merge: accepted values for the ``how`` argument of merge operations
MergeHow = Literal["left", "right", "inner", "outer", "cross"]

# join: accepted values for the ``how`` argument of index-level joins;
# unlike MergeHow, "cross" is not part of this alias
JoinHow = Literal["left", "right", "inner", "outer"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`join` supports `how="cross"` as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I typed `join` with `MergeHow` and used `JoinHow` only for the indices.

I actually took this from pandas-stubs

https://github.com/pandas-dev/pandas-stubs/blob/db3b883dbe8d22863c8aca37c89431afce3b57e2/pandas-stubs/_typing.pyi#L299-L300

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check how pandas-stubs types `DataFrame.join`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep


MatplotlibColor = Union[str, Sequence[float]]
TimeGrouperOrigin = Union[
"Timestamp", Literal["epoch", "start", "start_day", "end", "end_day"]
Expand Down
8 changes: 5 additions & 3 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@
IgnoreRaise,
IndexKeyFunc,
IndexLabel,
JoinHow,
Level,
MergeHow,
NaPosition,
PythonFuncType,
QuantileInterpolation,
Expand Down Expand Up @@ -9524,7 +9526,7 @@ def join(
self,
other: DataFrame | Series | list[DataFrame | Series],
on: IndexLabel | None = None,
how: str = "left",
how: JoinHow = "left",
lsuffix: str = "",
rsuffix: str = "",
sort: bool = False,
Expand Down Expand Up @@ -9697,7 +9699,7 @@ def _join_compat(
self,
other: DataFrame | Series | Iterable[DataFrame | Series],
on: IndexLabel | None = None,
how: str = "left",
how: MergeHow = "left",
lsuffix: str = "",
rsuffix: str = "",
sort: bool = False,
Expand Down Expand Up @@ -9783,7 +9785,7 @@ def _join_compat(
def merge(
self,
right: DataFrame | Series,
how: str = "inner",
how: MergeHow = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
Expand Down
30 changes: 17 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
F,
IgnoreRaise,
IndexLabel,
JoinHow,
Level,
Shape,
npt,
Expand Down Expand Up @@ -218,7 +219,7 @@ def join(
self,
other: Index,
*,
how: str_t = "left",
how: JoinHow = "left",
level=None,
return_indexers: bool = False,
sort: bool = False,
Expand Down Expand Up @@ -4325,7 +4326,7 @@ def join(
self,
other: Index,
*,
how: str_t = ...,
how: JoinHow = ...,
level: Level = ...,
return_indexers: Literal[True],
sort: bool = ...,
Expand All @@ -4337,7 +4338,7 @@ def join(
self,
other: Index,
*,
how: str_t = ...,
how: JoinHow = ...,
level: Level = ...,
return_indexers: Literal[False] = ...,
sort: bool = ...,
Expand All @@ -4349,7 +4350,7 @@ def join(
self,
other: Index,
*,
how: str_t = ...,
how: JoinHow = ...,
level: Level = ...,
return_indexers: bool = ...,
sort: bool = ...,
Expand All @@ -4362,7 +4363,7 @@ def join(
self,
other: Index,
*,
how: str_t = "left",
how: JoinHow = "left",
level: Level = None,
return_indexers: bool = False,
sort: bool = False,
Expand Down Expand Up @@ -4437,7 +4438,8 @@ def join(
return join_index, None, rindexer

if self._join_precedence < other._join_precedence:
how = {"right": "left", "left": "right"}.get(how, how)
flip: dict[JoinHow, JoinHow] = {"right": "left", "left": "right"}
how = flip.get(how, how)
join_index, lidx, ridx = other.join(
self, how=how, level=level, return_indexers=True
)
Expand Down Expand Up @@ -4483,7 +4485,7 @@ def join(

@final
def _join_via_get_indexer(
self, other: Index, how: str_t, sort: bool
self, other: Index, how: JoinHow, sort: bool
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
# Fallback if we do not have any fastpaths available based on
# uniqueness/monotonicity
Expand Down Expand Up @@ -4517,7 +4519,7 @@ def _join_via_get_indexer(
return join_index, lindexer, rindexer

@final
def _join_multi(self, other: Index, how: str_t):
def _join_multi(self, other: Index, how: JoinHow):
from pandas.core.indexes.multi import MultiIndex
from pandas.core.reshape.merge import restore_dropped_levels_multijoin

Expand Down Expand Up @@ -4589,7 +4591,8 @@ def _join_multi(self, other: Index, how: str_t):
self, other = other, self
flip_order = True
# flip if join method is right or left
how = {"right": "left", "left": "right"}.get(how, how)
flip: dict[JoinHow, JoinHow] = {"right": "left", "left": "right"}
how = flip.get(how, how)

level = other.names.index(jl)
result = self._join_level(other, level, how=how)
Expand All @@ -4600,7 +4603,7 @@ def _join_multi(self, other: Index, how: str_t):

@final
def _join_non_unique(
self, other: Index, how: str_t = "left"
self, other: Index, how: JoinHow = "left"
) -> tuple[Index, npt.NDArray[np.intp], npt.NDArray[np.intp]]:
from pandas.core.reshape.merge import get_join_indexers

Expand Down Expand Up @@ -4632,7 +4635,7 @@ def _join_non_unique(

@final
def _join_level(
self, other: Index, level, how: str_t = "left", keep_order: bool = True
self, other: Index, level, how: JoinHow = "left", keep_order: bool = True
) -> tuple[MultiIndex, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
"""
The join method *only* affects the level of the resulting
Expand Down Expand Up @@ -4683,7 +4686,8 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]:
flip_order = not isinstance(self, MultiIndex)
if flip_order:
left, right = right, left
how = {"right": "left", "left": "right"}.get(how, how)
flip: dict[JoinHow, JoinHow] = {"right": "left", "left": "right"}
how = flip.get(how, how)

assert isinstance(left, MultiIndex)

Expand Down Expand Up @@ -4780,7 +4784,7 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]:

@final
def _join_monotonic(
self, other: Index, how: str_t = "left"
self, other: Index, how: JoinHow = "left"
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
# We only get here with matching dtypes and both monotonic increasing
assert other.dtype == self.dtype
Expand Down
35 changes: 23 additions & 12 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
AxisInt,
DtypeObj,
IndexLabel,
JoinHow,
Literal,
MergeHow,
Shape,
Suffixes,
npt,
Expand Down Expand Up @@ -98,7 +101,7 @@
def merge(
left: DataFrame | Series,
right: DataFrame | Series,
how: str = "inner",
how: MergeHow = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
Expand Down Expand Up @@ -197,7 +200,7 @@ def merge_ordered(
right_by=None,
fill_method: str | None = None,
suffixes: Suffixes = ("_x", "_y"),
how: str = "outer",
how: JoinHow = "outer",
) -> DataFrame:
"""
Perform a merge for ordered data with optional filling/interpolation.
Expand Down Expand Up @@ -612,7 +615,7 @@ class _MergeOperation:
"""

_merge_type = "merge"
how: str
how: MergeHow | Literal["asof"]
on: IndexLabel | None
# left_on/right_on may be None when passed, but in validate_specification
# get replaced with non-None.
Expand All @@ -635,7 +638,7 @@ def __init__(
self,
left: DataFrame | Series,
right: DataFrame | Series,
how: str = "inner",
how: MergeHow | Literal["asof"] = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
Expand Down Expand Up @@ -1010,7 +1013,8 @@ def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]
def _get_join_info(
self,
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:

# make mypy happy
assert self.how != "cross"
left_ax = self.left.axes[self.axis]
right_ax = self.right.axes[self.axis]

Expand Down Expand Up @@ -1072,7 +1076,7 @@ def _create_join_index(
index: Index,
other_index: Index,
indexer: npt.NDArray[np.intp],
how: str = "left",
how: JoinHow = "left",
) -> Index:
"""
Create a join index by rearranging one index to match another
Expand Down Expand Up @@ -1406,7 +1410,7 @@ def _maybe_coerce_merge_keys(self) -> None:

def _create_cross_configuration(
self, left: DataFrame, right: DataFrame
) -> tuple[DataFrame, DataFrame, str, str]:
) -> tuple[DataFrame, DataFrame, JoinHow, str]:
"""
Creates the configuration to dispatch the cross operation to inner join,
e.g. adding a join column and resetting parameters. Join column is added
Expand All @@ -1424,7 +1428,7 @@ def _create_cross_configuration(
to join over.
"""
cross_col = f"_cross_{uuid.uuid4()}"
how = "inner"
how: JoinHow = "inner"
return (
left.assign(**{cross_col: 1}),
right.assign(**{cross_col: 1}),
Expand Down Expand Up @@ -1584,7 +1588,11 @@ def _validate(self, validate: str) -> None:


def get_join_indexers(
left_keys, right_keys, sort: bool = False, how: str = "inner", **kwargs
left_keys,
right_keys,
sort: bool = False,
how: MergeHow | Literal["asof"] = "inner",
**kwargs,
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
"""

Expand Down Expand Up @@ -1757,7 +1765,7 @@ def __init__(
axis: AxisInt = 1,
suffixes: Suffixes = ("_x", "_y"),
fill_method: str | None = None,
how: str = "outer",
how: JoinHow | Literal["asof"] = "outer",
) -> None:

self.fill_method = fill_method
Expand Down Expand Up @@ -1847,7 +1855,7 @@ def __init__(
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
fill_method: str | None = None,
how: str = "asof",
how: Literal["asof"] = "asof",
tolerance=None,
allow_exact_matches: bool = True,
direction: str = "backward",
Expand Down Expand Up @@ -2256,7 +2264,10 @@ def _left_join_on_index(


def _factorize_keys(
lk: ArrayLike, rk: ArrayLike, sort: bool = True, how: str = "inner"
lk: ArrayLike,
rk: ArrayLike,
sort: bool = True,
how: MergeHow | Literal["asof"] = "inner",
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
"""
Encode left and right keys as enumerated types.
Expand Down