30
30
ArrayLike ,
31
31
DtypeObj ,
32
32
IndexLabel ,
33
+ Shape ,
33
34
Suffixes ,
34
35
npt ,
35
36
)
@@ -625,6 +626,9 @@ class _MergeOperation:
625
626
copy : bool
626
627
indicator : bool
627
628
validate : str | None
629
+ join_names : list [Hashable ]
630
+ right_join_keys : list [AnyArrayLike ]
631
+ left_join_keys : list [AnyArrayLike ]
628
632
629
633
def __init__ (
630
634
self ,
@@ -960,9 +964,9 @@ def _maybe_add_join_keys(
960
964
rvals = result [name ]._values
961
965
else :
962
966
# TODO: can we pin down take_right's type earlier?
963
- take_right = extract_array (take_right , extract_numpy = True )
964
- rfill = na_value_for_dtype (take_right .dtype )
965
- rvals = algos .take_nd (take_right , right_indexer , fill_value = rfill )
967
+ taker = extract_array (take_right , extract_numpy = True )
968
+ rfill = na_value_for_dtype (taker .dtype )
969
+ rvals = algos .take_nd (taker , right_indexer , fill_value = rfill )
966
970
967
971
# if we have an all missing left_indexer
968
972
# make sure to just use the right values or vice-versa
@@ -1098,7 +1102,9 @@ def _create_join_index(
1098
1102
index = index .append (Index ([fill_value ]))
1099
1103
return index .take (indexer )
1100
1104
1101
- def _get_merge_keys (self ):
1105
+ def _get_merge_keys (
1106
+ self ,
1107
+ ) -> tuple [list [AnyArrayLike ], list [AnyArrayLike ], list [Hashable ]]:
1102
1108
"""
1103
1109
Note: has side effects (copy/delete key columns)
1104
1110
@@ -1117,8 +1123,8 @@ def _get_merge_keys(self):
1117
1123
left_keys : list [AnyArrayLike ] = []
1118
1124
right_keys : list [AnyArrayLike ] = []
1119
1125
join_names : list [Hashable ] = []
1120
- right_drop = []
1121
- left_drop = []
1126
+ right_drop : list [ Hashable ] = []
1127
+ left_drop : list [ Hashable ] = []
1122
1128
1123
1129
left , right = self .left , self .right
1124
1130
@@ -1168,6 +1174,7 @@ def _get_merge_keys(self):
1168
1174
right_keys .append (right .index )
1169
1175
if lk is not None and lk == rk : # FIXME: what about other NAs?
1170
1176
# avoid key upcast in corner case (length-0)
1177
+ lk = cast (Hashable , lk )
1171
1178
if len (left ) > 0 :
1172
1179
right_drop .append (rk )
1173
1180
else :
@@ -1260,6 +1267,8 @@ def _maybe_coerce_merge_keys(self) -> None:
1260
1267
# if either left or right is a categorical
1261
1268
# then the must match exactly in categories & ordered
1262
1269
if lk_is_cat and rk_is_cat :
1270
+ lk = cast (Categorical , lk )
1271
+ rk = cast (Categorical , rk )
1263
1272
if lk ._categories_match_up_to_permutation (rk ):
1264
1273
continue
1265
1274
@@ -1286,7 +1295,22 @@ def _maybe_coerce_merge_keys(self) -> None:
1286
1295
elif is_integer_dtype (rk .dtype ) and is_float_dtype (lk .dtype ):
1287
1296
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
1288
1297
with np .errstate (invalid = "ignore" ):
1289
- if not (lk == lk .astype (rk .dtype ))[~ np .isnan (lk )].all ():
1298
+ # error: Argument 1 to "astype" of "ndarray" has incompatible
1299
+ # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
1300
+ # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
1301
+ casted = lk .astype (rk .dtype ) # type: ignore[arg-type]
1302
+
1303
+ # Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
1304
+ # incompatible type "Union[ExtensionArray, ndarray[Any, Any],
1305
+ # Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
1306
+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
1307
+ # float, complex, str, bytes, _NestedSequence[Union[bool,
1308
+ # int, float, complex, str, bytes]]]"
1309
+ mask = ~ np .isnan (lk ) # type: ignore[arg-type]
1310
+ match = lk == casted
1311
+ # error: Item "ExtensionArray" of "Union[ExtensionArray,
1312
+ # ndarray[Any, Any], Any]" has no attribute "all"
1313
+ if not match [mask ].all (): # type: ignore[union-attr]
1290
1314
warnings .warn (
1291
1315
"You are merging on int and float "
1292
1316
"columns where the float values "
@@ -1299,7 +1323,22 @@ def _maybe_coerce_merge_keys(self) -> None:
1299
1323
elif is_float_dtype (rk .dtype ) and is_integer_dtype (lk .dtype ):
1300
1324
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
1301
1325
with np .errstate (invalid = "ignore" ):
1302
- if not (rk == rk .astype (lk .dtype ))[~ np .isnan (rk )].all ():
1326
+ # error: Argument 1 to "astype" of "ndarray" has incompatible
1327
+ # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
1328
+ # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
1329
+ casted = rk .astype (lk .dtype ) # type: ignore[arg-type]
1330
+
1331
+ # Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
1332
+ # incompatible type "Union[ExtensionArray, ndarray[Any, Any],
1333
+ # Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
1334
+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
1335
+ # float, complex, str, bytes, _NestedSequence[Union[bool,
1336
+ # int, float, complex, str, bytes]]]"
1337
+ mask = ~ np .isnan (rk ) # type: ignore[arg-type]
1338
+ match = rk == casted
1339
+ # error: Item "ExtensionArray" of "Union[ExtensionArray,
1340
+ # ndarray[Any, Any], Any]" has no attribute "all"
1341
+ if not match [mask ].all (): # type: ignore[union-attr]
1303
1342
warnings .warn (
1304
1343
"You are merging on int and float "
1305
1344
"columns where the float values "
@@ -1370,11 +1409,11 @@ def _maybe_coerce_merge_keys(self) -> None:
1370
1409
# columns, and end up trying to merge
1371
1410
# incompatible dtypes. See GH 16900.
1372
1411
if name in self .left .columns :
1373
- typ = lk .categories .dtype if lk_is_cat else object
1412
+ typ = cast ( Categorical , lk ) .categories .dtype if lk_is_cat else object
1374
1413
self .left = self .left .copy ()
1375
1414
self .left [name ] = self .left [name ].astype (typ )
1376
1415
if name in self .right .columns :
1377
- typ = rk .categories .dtype if rk_is_cat else object
1416
+ typ = cast ( Categorical , rk ) .categories .dtype if rk_is_cat else object
1378
1417
self .right = self .right .copy ()
1379
1418
self .right [name ] = self .right [name ].astype (typ )
1380
1419
@@ -1592,7 +1631,7 @@ def get_join_indexers(
1592
1631
llab , rlab , shape = (list (x ) for x in zipped )
1593
1632
1594
1633
# get flat i8 keys from label lists
1595
- lkey , rkey = _get_join_keys (llab , rlab , shape , sort )
1634
+ lkey , rkey = _get_join_keys (llab , rlab , tuple ( shape ) , sort )
1596
1635
1597
1636
# factorize keys to a dense i8 space
1598
1637
# `count` is the num. of unique keys
@@ -1922,7 +1961,9 @@ def _validate_left_right_on(self, left_on, right_on):
1922
1961
1923
1962
return left_on , right_on
1924
1963
1925
- def _get_merge_keys (self ):
1964
+ def _get_merge_keys (
1965
+ self ,
1966
+ ) -> tuple [list [AnyArrayLike ], list [AnyArrayLike ], list [Hashable ]]:
1926
1967
1927
1968
# note this function has side effects
1928
1969
(left_join_keys , right_join_keys , join_names ) = super ()._get_merge_keys ()
@@ -1954,7 +1995,8 @@ def _get_merge_keys(self):
1954
1995
if self .tolerance is not None :
1955
1996
1956
1997
if self .left_index :
1957
- lt = self .left .index
1998
+ # Actually more specifically an Index
1999
+ lt = cast (AnyArrayLike , self .left .index )
1958
2000
else :
1959
2001
lt = left_join_keys [- 1 ]
1960
2002
@@ -2069,21 +2111,21 @@ def injection(obj):
2069
2111
2070
2112
# get tuple representation of values if more than one
2071
2113
if len (left_by_values ) == 1 :
2072
- left_by_values = left_by_values [0 ]
2073
- right_by_values = right_by_values [0 ]
2114
+ lbv = left_by_values [0 ]
2115
+ rbv = right_by_values [0 ]
2074
2116
else :
2075
2117
# We get here with non-ndarrays in test_merge_by_col_tz_aware
2076
2118
# and test_merge_groupby_multiple_column_with_categorical_column
2077
- left_by_values = flip (left_by_values )
2078
- right_by_values = flip (right_by_values )
2119
+ lbv = flip (left_by_values )
2120
+ rbv = flip (right_by_values )
2079
2121
2080
2122
# upcast 'by' parameter because HashTable is limited
2081
- by_type = _get_cython_type_upcast (left_by_values .dtype )
2123
+ by_type = _get_cython_type_upcast (lbv .dtype )
2082
2124
by_type_caster = _type_casters [by_type ]
2083
2125
# error: Cannot call function of unknown type
2084
- left_by_values = by_type_caster (left_by_values ) # type: ignore[operator]
2126
+ left_by_values = by_type_caster (lbv ) # type: ignore[operator]
2085
2127
# error: Cannot call function of unknown type
2086
- right_by_values = by_type_caster (right_by_values ) # type: ignore[operator]
2128
+ right_by_values = by_type_caster (rbv ) # type: ignore[operator]
2087
2129
2088
2130
# choose appropriate function by type
2089
2131
func = _asof_by_function (self .direction )
@@ -2139,7 +2181,7 @@ def _get_multiindex_indexer(
2139
2181
rcodes [i ][mask ] = shape [i ] - 1
2140
2182
2141
2183
# get flat i8 join keys
2142
- lkey , rkey = _get_join_keys (lcodes , rcodes , shape , sort )
2184
+ lkey , rkey = _get_join_keys (lcodes , rcodes , tuple ( shape ) , sort )
2143
2185
2144
2186
# factorize keys to a dense i8 space
2145
2187
lkey , rkey , count = _factorize_keys (lkey , rkey , sort = sort )
@@ -2377,7 +2419,12 @@ def _sort_labels(
2377
2419
return new_left , new_right
2378
2420
2379
2421
2380
- def _get_join_keys (llab , rlab , shape , sort : bool ):
2422
+ def _get_join_keys (
2423
+ llab : list [npt .NDArray [np .int64 | np .intp ]],
2424
+ rlab : list [npt .NDArray [np .int64 | np .intp ]],
2425
+ shape : Shape ,
2426
+ sort : bool ,
2427
+ ) -> tuple [npt .NDArray [np .int64 ], npt .NDArray [np .int64 ]]:
2381
2428
2382
2429
# how many levels can be done without overflow
2383
2430
nlev = next (
@@ -2405,7 +2452,7 @@ def _get_join_keys(llab, rlab, shape, sort: bool):
2405
2452
2406
2453
llab = [lkey ] + llab [nlev :]
2407
2454
rlab = [rkey ] + rlab [nlev :]
2408
- shape = [ count ] + shape [nlev :]
2455
+ shape = ( count ,) + shape [nlev :]
2409
2456
2410
2457
return _get_join_keys (llab , rlab , shape , sort )
2411
2458
0 commit comments