Skip to content

Commit d8fb17c

Browse files
authored
TYP: reshape.merge (#48590)
* TYP: _get_join_keys * TYP: reshape.merge
1 parent 438b957 commit d8fb17c

File tree

1 file changed

+70
-23
lines changed

1 file changed

+70
-23
lines changed

pandas/core/reshape/merge.py

+70-23
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ArrayLike,
3131
DtypeObj,
3232
IndexLabel,
33+
Shape,
3334
Suffixes,
3435
npt,
3536
)
@@ -625,6 +626,9 @@ class _MergeOperation:
625626
copy: bool
626627
indicator: bool
627628
validate: str | None
629+
join_names: list[Hashable]
630+
right_join_keys: list[AnyArrayLike]
631+
left_join_keys: list[AnyArrayLike]
628632

629633
def __init__(
630634
self,
@@ -960,9 +964,9 @@ def _maybe_add_join_keys(
960964
rvals = result[name]._values
961965
else:
962966
# TODO: can we pin down take_right's type earlier?
963-
take_right = extract_array(take_right, extract_numpy=True)
964-
rfill = na_value_for_dtype(take_right.dtype)
965-
rvals = algos.take_nd(take_right, right_indexer, fill_value=rfill)
967+
taker = extract_array(take_right, extract_numpy=True)
968+
rfill = na_value_for_dtype(taker.dtype)
969+
rvals = algos.take_nd(taker, right_indexer, fill_value=rfill)
966970

967971
# if we have an all missing left_indexer
968972
# make sure to just use the right values or vice-versa
@@ -1098,7 +1102,9 @@ def _create_join_index(
10981102
index = index.append(Index([fill_value]))
10991103
return index.take(indexer)
11001104

1101-
def _get_merge_keys(self):
1105+
def _get_merge_keys(
1106+
self,
1107+
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
11021108
"""
11031109
Note: has side effects (copy/delete key columns)
11041110
@@ -1117,8 +1123,8 @@ def _get_merge_keys(self):
11171123
left_keys: list[AnyArrayLike] = []
11181124
right_keys: list[AnyArrayLike] = []
11191125
join_names: list[Hashable] = []
1120-
right_drop = []
1121-
left_drop = []
1126+
right_drop: list[Hashable] = []
1127+
left_drop: list[Hashable] = []
11221128

11231129
left, right = self.left, self.right
11241130

@@ -1168,6 +1174,7 @@ def _get_merge_keys(self):
11681174
right_keys.append(right.index)
11691175
if lk is not None and lk == rk: # FIXME: what about other NAs?
11701176
# avoid key upcast in corner case (length-0)
1177+
lk = cast(Hashable, lk)
11711178
if len(left) > 0:
11721179
right_drop.append(rk)
11731180
else:
@@ -1260,6 +1267,8 @@ def _maybe_coerce_merge_keys(self) -> None:
12601267
# if either left or right is a categorical
12611268
# then the must match exactly in categories & ordered
12621269
if lk_is_cat and rk_is_cat:
1270+
lk = cast(Categorical, lk)
1271+
rk = cast(Categorical, rk)
12631272
if lk._categories_match_up_to_permutation(rk):
12641273
continue
12651274

@@ -1286,7 +1295,22 @@ def _maybe_coerce_merge_keys(self) -> None:
12861295
elif is_integer_dtype(rk.dtype) and is_float_dtype(lk.dtype):
12871296
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
12881297
with np.errstate(invalid="ignore"):
1289-
if not (lk == lk.astype(rk.dtype))[~np.isnan(lk)].all():
1298+
# error: Argument 1 to "astype" of "ndarray" has incompatible
1299+
# type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
1300+
# "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
1301+
casted = lk.astype(rk.dtype) # type: ignore[arg-type]
1302+
1303+
# Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
1304+
# incompatible type "Union[ExtensionArray, ndarray[Any, Any],
1305+
# Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
1306+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
1307+
# float, complex, str, bytes, _NestedSequence[Union[bool,
1308+
# int, float, complex, str, bytes]]]"
1309+
mask = ~np.isnan(lk) # type: ignore[arg-type]
1310+
match = lk == casted
1311+
# error: Item "ExtensionArray" of "Union[ExtensionArray,
1312+
# ndarray[Any, Any], Any]" has no attribute "all"
1313+
if not match[mask].all(): # type: ignore[union-attr]
12901314
warnings.warn(
12911315
"You are merging on int and float "
12921316
"columns where the float values "
@@ -1299,7 +1323,22 @@ def _maybe_coerce_merge_keys(self) -> None:
12991323
elif is_float_dtype(rk.dtype) and is_integer_dtype(lk.dtype):
13001324
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
13011325
with np.errstate(invalid="ignore"):
1302-
if not (rk == rk.astype(lk.dtype))[~np.isnan(rk)].all():
1326+
# error: Argument 1 to "astype" of "ndarray" has incompatible
1327+
# type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
1328+
# "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
1329+
casted = rk.astype(lk.dtype) # type: ignore[arg-type]
1330+
1331+
# Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
1332+
# incompatible type "Union[ExtensionArray, ndarray[Any, Any],
1333+
# Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
1334+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
1335+
# float, complex, str, bytes, _NestedSequence[Union[bool,
1336+
# int, float, complex, str, bytes]]]"
1337+
mask = ~np.isnan(rk) # type: ignore[arg-type]
1338+
match = rk == casted
1339+
# error: Item "ExtensionArray" of "Union[ExtensionArray,
1340+
# ndarray[Any, Any], Any]" has no attribute "all"
1341+
if not match[mask].all(): # type: ignore[union-attr]
13031342
warnings.warn(
13041343
"You are merging on int and float "
13051344
"columns where the float values "
@@ -1370,11 +1409,11 @@ def _maybe_coerce_merge_keys(self) -> None:
13701409
# columns, and end up trying to merge
13711410
# incompatible dtypes. See GH 16900.
13721411
if name in self.left.columns:
1373-
typ = lk.categories.dtype if lk_is_cat else object
1412+
typ = cast(Categorical, lk).categories.dtype if lk_is_cat else object
13741413
self.left = self.left.copy()
13751414
self.left[name] = self.left[name].astype(typ)
13761415
if name in self.right.columns:
1377-
typ = rk.categories.dtype if rk_is_cat else object
1416+
typ = cast(Categorical, rk).categories.dtype if rk_is_cat else object
13781417
self.right = self.right.copy()
13791418
self.right[name] = self.right[name].astype(typ)
13801419

@@ -1592,7 +1631,7 @@ def get_join_indexers(
15921631
llab, rlab, shape = (list(x) for x in zipped)
15931632

15941633
# get flat i8 keys from label lists
1595-
lkey, rkey = _get_join_keys(llab, rlab, shape, sort)
1634+
lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort)
15961635

15971636
# factorize keys to a dense i8 space
15981637
# `count` is the num. of unique keys
@@ -1922,7 +1961,9 @@ def _validate_left_right_on(self, left_on, right_on):
19221961

19231962
return left_on, right_on
19241963

1925-
def _get_merge_keys(self):
1964+
def _get_merge_keys(
1965+
self,
1966+
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
19261967

19271968
# note this function has side effects
19281969
(left_join_keys, right_join_keys, join_names) = super()._get_merge_keys()
@@ -1954,7 +1995,8 @@ def _get_merge_keys(self):
19541995
if self.tolerance is not None:
19551996

19561997
if self.left_index:
1957-
lt = self.left.index
1998+
# Actually more specifically an Index
1999+
lt = cast(AnyArrayLike, self.left.index)
19582000
else:
19592001
lt = left_join_keys[-1]
19602002

@@ -2069,21 +2111,21 @@ def injection(obj):
20692111

20702112
# get tuple representation of values if more than one
20712113
if len(left_by_values) == 1:
2072-
left_by_values = left_by_values[0]
2073-
right_by_values = right_by_values[0]
2114+
lbv = left_by_values[0]
2115+
rbv = right_by_values[0]
20742116
else:
20752117
# We get here with non-ndarrays in test_merge_by_col_tz_aware
20762118
# and test_merge_groupby_multiple_column_with_categorical_column
2077-
left_by_values = flip(left_by_values)
2078-
right_by_values = flip(right_by_values)
2119+
lbv = flip(left_by_values)
2120+
rbv = flip(right_by_values)
20792121

20802122
# upcast 'by' parameter because HashTable is limited
2081-
by_type = _get_cython_type_upcast(left_by_values.dtype)
2123+
by_type = _get_cython_type_upcast(lbv.dtype)
20822124
by_type_caster = _type_casters[by_type]
20832125
# error: Cannot call function of unknown type
2084-
left_by_values = by_type_caster(left_by_values) # type: ignore[operator]
2126+
left_by_values = by_type_caster(lbv) # type: ignore[operator]
20852127
# error: Cannot call function of unknown type
2086-
right_by_values = by_type_caster(right_by_values) # type: ignore[operator]
2128+
right_by_values = by_type_caster(rbv) # type: ignore[operator]
20872129

20882130
# choose appropriate function by type
20892131
func = _asof_by_function(self.direction)
@@ -2139,7 +2181,7 @@ def _get_multiindex_indexer(
21392181
rcodes[i][mask] = shape[i] - 1
21402182

21412183
# get flat i8 join keys
2142-
lkey, rkey = _get_join_keys(lcodes, rcodes, shape, sort)
2184+
lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort)
21432185

21442186
# factorize keys to a dense i8 space
21452187
lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort)
@@ -2377,7 +2419,12 @@ def _sort_labels(
23772419
return new_left, new_right
23782420

23792421

2380-
def _get_join_keys(llab, rlab, shape, sort: bool):
2422+
def _get_join_keys(
2423+
llab: list[npt.NDArray[np.int64 | np.intp]],
2424+
rlab: list[npt.NDArray[np.int64 | np.intp]],
2425+
shape: Shape,
2426+
sort: bool,
2427+
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]:
23812428

23822429
# how many levels can be done without overflow
23832430
nlev = next(
@@ -2405,7 +2452,7 @@ def _get_join_keys(llab, rlab, shape, sort: bool):
24052452

24062453
llab = [lkey] + llab[nlev:]
24072454
rlab = [rkey] + rlab[nlev:]
2408-
shape = [count] + shape[nlev:]
2455+
shape = (count,) + shape[nlev:]
24092456

24102457
return _get_join_keys(llab, rlab, shape, sort)
24112458

0 commit comments

Comments
 (0)