Skip to content

Commit 5f1307e

Browse files
authored
REF: pass arguments to Index._foo_indexer correctly (#41024)
1 parent 4e16e4f commit 5f1307e

File tree

5 files changed

+94
-86
lines changed

5 files changed

+94
-86
lines changed

pandas/core/indexes/base.py

+63-37
Original file line number | Diff line number | Diff line change
@@ -302,23 +302,47 @@ class Index(IndexOpsMixin, PandasObject):
302302
# for why we need to wrap these instead of making them class attributes
303303
# Moreover, cython will choose the appropriate-dtyped sub-function
304304
# given the dtypes of the passed arguments
305-
def _left_indexer_unique(self, left: np.ndarray, right: np.ndarray) -> np.ndarray:
306-
return libjoin.left_join_indexer_unique(left, right)
307305

306+
@final
307+
def _left_indexer_unique(self: _IndexT, other: _IndexT) -> np.ndarray:
308+
# -> np.ndarray[np.intp]
309+
# Caller is responsible for ensuring other.dtype == self.dtype
310+
sv = self._get_join_target()
311+
ov = other._get_join_target()
312+
return libjoin.left_join_indexer_unique(sv, ov)
313+
314+
@final
308315
def _left_indexer(
309-
self, left: np.ndarray, right: np.ndarray
310-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
311-
return libjoin.left_join_indexer(left, right)
316+
self: _IndexT, other: _IndexT
317+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
318+
# Caller is responsible for ensuring other.dtype == self.dtype
319+
sv = self._get_join_target()
320+
ov = other._get_join_target()
321+
joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov)
322+
joined = self._from_join_target(joined_ndarray)
323+
return joined, lidx, ridx
312324

325+
@final
313326
def _inner_indexer(
314-
self, left: np.ndarray, right: np.ndarray
315-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
316-
return libjoin.inner_join_indexer(left, right)
327+
self: _IndexT, other: _IndexT
328+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
329+
# Caller is responsible for ensuring other.dtype == self.dtype
330+
sv = self._get_join_target()
331+
ov = other._get_join_target()
332+
joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov)
333+
joined = self._from_join_target(joined_ndarray)
334+
return joined, lidx, ridx
317335

336+
@final
318337
def _outer_indexer(
319-
self, left: np.ndarray, right: np.ndarray
320-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
321-
return libjoin.outer_join_indexer(left, right)
338+
self: _IndexT, other: _IndexT
339+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
340+
# Caller is responsible for ensuring other.dtype == self.dtype
341+
sv = self._get_join_target()
342+
ov = other._get_join_target()
343+
joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov)
344+
joined = self._from_join_target(joined_ndarray)
345+
return joined, lidx, ridx
322346

323347
_typ = "index"
324348
_data: ExtensionArray | np.ndarray
@@ -2965,11 +2989,7 @@ def _union(self, other: Index, sort):
29652989
):
29662990
# Both are unique and monotonic, so can use outer join
29672991
try:
2968-
# error: Argument 1 to "_outer_indexer" of "Index" has incompatible type
2969-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
2970-
# error: Argument 2 to "_outer_indexer" of "Index" has incompatible type
2971-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
2972-
return self._outer_indexer(lvals, rvals)[0] # type: ignore[arg-type]
2992+
return self._outer_indexer(other)[0]
29732993
except (TypeError, IncompatibleFrequency):
29742994
# incomparable objects
29752995
value_list = list(lvals)
@@ -3083,13 +3103,10 @@ def _intersection(self, other: Index, sort=False):
30833103
"""
30843104
# TODO(EA): setops-refactor, clean all this up
30853105
lvals = self._values
3086-
rvals = other._values
30873106

30883107
if self.is_monotonic and other.is_monotonic:
30893108
try:
3090-
# error: Argument 1 to "_inner_indexer" of "Index" has incompatible type
3091-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
3092-
result = self._inner_indexer(lvals, rvals)[0] # type: ignore[arg-type]
3109+
result = self._inner_indexer(other)[0]
30933110
except TypeError:
30943111
pass
30953112
else:
@@ -4088,8 +4105,8 @@ def _join_non_unique(self, other, how="left"):
40884105
# We only get here if dtypes match
40894106
assert self.dtype == other.dtype
40904107

4091-
lvalues = self._get_engine_target()
4092-
rvalues = other._get_engine_target()
4108+
lvalues = self._get_join_target()
4109+
rvalues = other._get_join_target()
40934110

40944111
left_idx, right_idx = get_join_indexers(
40954112
[lvalues], [rvalues], how=how, sort=True
@@ -4102,7 +4119,8 @@ def _join_non_unique(self, other, how="left"):
41024119
mask = left_idx == -1
41034120
np.putmask(join_array, mask, rvalues.take(right_idx))
41044121

4105-
join_index = self._wrap_joined_index(join_array, other)
4122+
join_arraylike = self._from_join_target(join_array)
4123+
join_index = self._wrap_joined_index(join_arraylike, other)
41064124

41074125
return join_index, left_idx, right_idx
41084126

@@ -4260,9 +4278,6 @@ def _join_monotonic(self, other: Index, how="left"):
42604278
ret_index = other if how == "right" else self
42614279
return ret_index, None, None
42624280

4263-
sv = self._get_engine_target()
4264-
ov = other._get_engine_target()
4265-
42664281
ridx: np.ndarray | None
42674282
lidx: np.ndarray | None
42684283

@@ -4271,36 +4286,34 @@ def _join_monotonic(self, other: Index, how="left"):
42714286
if how == "left":
42724287
join_index = self
42734288
lidx = None
4274-
ridx = self._left_indexer_unique(sv, ov)
4289+
ridx = self._left_indexer_unique(other)
42754290
elif how == "right":
42764291
join_index = other
4277-
lidx = self._left_indexer_unique(ov, sv)
4292+
lidx = other._left_indexer_unique(self)
42784293
ridx = None
42794294
elif how == "inner":
4280-
join_array, lidx, ridx = self._inner_indexer(sv, ov)
4295+
join_array, lidx, ridx = self._inner_indexer(other)
42814296
join_index = self._wrap_joined_index(join_array, other)
42824297
elif how == "outer":
4283-
join_array, lidx, ridx = self._outer_indexer(sv, ov)
4298+
join_array, lidx, ridx = self._outer_indexer(other)
42844299
join_index = self._wrap_joined_index(join_array, other)
42854300
else:
42864301
if how == "left":
4287-
join_array, lidx, ridx = self._left_indexer(sv, ov)
4302+
join_array, lidx, ridx = self._left_indexer(other)
42884303
elif how == "right":
4289-
join_array, ridx, lidx = self._left_indexer(ov, sv)
4304+
join_array, ridx, lidx = other._left_indexer(self)
42904305
elif how == "inner":
4291-
join_array, lidx, ridx = self._inner_indexer(sv, ov)
4306+
join_array, lidx, ridx = self._inner_indexer(other)
42924307
elif how == "outer":
4293-
join_array, lidx, ridx = self._outer_indexer(sv, ov)
4308+
join_array, lidx, ridx = self._outer_indexer(other)
42944309

42954310
join_index = self._wrap_joined_index(join_array, other)
42964311

42974312
lidx = None if lidx is None else ensure_platform_int(lidx)
42984313
ridx = None if ridx is None else ensure_platform_int(ridx)
42994314
return join_index, lidx, ridx
43004315

4301-
def _wrap_joined_index(
4302-
self: _IndexT, joined: np.ndarray, other: _IndexT
4303-
) -> _IndexT:
4316+
def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT:
43044317
assert other.dtype == self.dtype
43054318

43064319
if isinstance(self, ABCMultiIndex):
@@ -4378,6 +4391,19 @@ def _get_engine_target(self) -> np.ndarray:
43784391
# ndarray]", expected "ndarray")
43794392
return self._values # type: ignore[return-value]
43804393

4394+
def _get_join_target(self) -> np.ndarray:
4395+
"""
4396+
Get the ndarray that we will pass to libjoin functions.
4397+
"""
4398+
return self._get_engine_target()
4399+
4400+
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
4401+
"""
4402+
Cast the ndarray returned from one of the libjoin.foo_indexer functions
4403+
back to type(self)._data.
4404+
"""
4405+
return result
4406+
43814407
@doc(IndexOpsMixin._memory_usage)
43824408
def memory_usage(self, deep: bool = False) -> int:
43834409
result = self._memory_usage(deep=deep)

pandas/core/indexes/datetimelike.py

+11-41
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
NaT,
2121
Timedelta,
2222
iNaT,
23-
join as libjoin,
2423
lib,
2524
)
2625
from pandas._libs.tslibs import (
@@ -75,36 +74,6 @@
7574
_T = TypeVar("_T", bound="DatetimeIndexOpsMixin")
7675

7776

78-
def _join_i8_wrapper(joinf, with_indexers: bool = True):
79-
"""
80-
Create the join wrapper methods.
81-
"""
82-
83-
# error: 'staticmethod' used with a non-method
84-
@staticmethod # type: ignore[misc]
85-
def wrapper(left, right):
86-
# Note: these only get called with left.dtype == right.dtype
87-
orig_left = left
88-
89-
left = left.view("i8")
90-
right = right.view("i8")
91-
92-
results = joinf(left, right)
93-
if with_indexers:
94-
95-
join_index, left_indexer, right_indexer = results
96-
if not isinstance(orig_left, np.ndarray):
97-
# When called from Index._intersection/_union, we have the EA
98-
join_index = join_index.view(orig_left._ndarray.dtype)
99-
join_index = orig_left._from_backing_data(join_index)
100-
101-
return join_index, left_indexer, right_indexer
102-
103-
return results
104-
105-
return wrapper
106-
107-
10877
@inherit_names(
10978
["inferred_freq", "_resolution_obj", "resolution"],
11079
DatetimeLikeArrayMixin,
@@ -603,13 +572,6 @@ def insert(self, loc: int, item):
603572
# --------------------------------------------------------------------
604573
# Join/Set Methods
605574

606-
_inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer)
607-
_outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer)
608-
_left_indexer = _join_i8_wrapper(libjoin.left_join_indexer)
609-
_left_indexer_unique = _join_i8_wrapper(
610-
libjoin.left_join_indexer_unique, with_indexers=False
611-
)
612-
613575
def _get_join_freq(self, other):
614576
"""
615577
Get the freq to attach to the result of a join operation.
@@ -621,14 +583,22 @@ def _get_join_freq(self, other):
621583
freq = self.freq if self._can_fast_union(other) else None
622584
return freq
623585

624-
def _wrap_joined_index(self, joined: np.ndarray, other):
586+
def _wrap_joined_index(self, joined, other):
625587
assert other.dtype == self.dtype, (other.dtype, self.dtype)
626-
assert joined.dtype == "i8" or joined.dtype == self.dtype, joined.dtype
627-
joined = joined.view(self._data._ndarray.dtype)
628588
result = super()._wrap_joined_index(joined, other)
629589
result._data._freq = self._get_join_freq(other)
630590
return result
631591

592+
def _get_join_target(self) -> np.ndarray:
593+
return self._data._ndarray.view("i8")
594+
595+
def _from_join_target(self, result: np.ndarray):
596+
# view e.g. i8 back to M8[ns]
597+
result = result.view(self._data._ndarray.dtype)
598+
return self._data._from_backing_data(result)
599+
600+
# --------------------------------------------------------------------
601+
632602
@doc(Index._convert_arr_indexer)
633603
def _convert_arr_indexer(self, keyarr):
634604
try:

pandas/core/indexes/extension.py

+17-3
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313

14+
from pandas._typing import ArrayLike
1415
from pandas.compat.numpy import function as nv
1516
from pandas.errors import AbstractMethodError
1617
from pandas.util._decorators import (
@@ -300,6 +301,11 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
300301
def _get_engine_target(self) -> np.ndarray:
301302
return np.asarray(self._data)
302303

304+
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
305+
# ATM this is only for IntervalIndex, implicit assumption
306+
# about _get_engine_target
307+
return type(self._data)._from_sequence(result, dtype=self.dtype)
308+
303309
def delete(self, loc):
304310
"""
305311
Make new Index with passed location(-s) deleted
@@ -410,6 +416,10 @@ def _simple_new(
410416
def _get_engine_target(self) -> np.ndarray:
411417
return self._data._ndarray
412418

419+
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
420+
assert result.dtype == self._data._ndarray.dtype
421+
return self._data._from_backing_data(result)
422+
413423
def insert(self: _T, loc: int, item) -> Index:
414424
"""
415425
Make new Index inserting new item at location. Follows
@@ -458,7 +468,11 @@ def putmask(self, mask, value) -> Index:
458468

459469
return type(self)._simple_new(res_values, name=self.name)
460470

461-
def _wrap_joined_index(self: _T, joined: np.ndarray, other: _T) -> _T:
471+
# error: Argument 1 of "_wrap_joined_index" is incompatible with supertype
472+
# "Index"; supertype defines the argument type as "Union[ExtensionArray, ndarray]"
473+
def _wrap_joined_index( # type: ignore[override]
474+
self: _T, joined: NDArrayBackedExtensionArray, other: _T
475+
) -> _T:
462476
name = get_op_result_name(self, other)
463-
arr = self._data._from_backing_data(joined)
464-
return type(self)._simple_new(arr, name=name)
477+
478+
return type(self)._simple_new(joined, name=name)

pandas/core/indexes/multi.py

+2-4
Original file line number | Diff line number | Diff line change
@@ -3613,14 +3613,12 @@ def _maybe_match_names(self, other):
36133613

36143614
def _intersection(self, other, sort=False) -> MultiIndex:
36153615
other, result_names = self._convert_can_do_setop(other)
3616-
3617-
lvals = self._values
3618-
rvals = other._values.astype(object, copy=False)
3616+
other = other.astype(object, copy=False)
36193617

36203618
uniq_tuples = None # flag whether _inner_indexer was successful
36213619
if self.is_monotonic and other.is_monotonic:
36223620
try:
3623-
inner_tuples = self._inner_indexer(lvals, rvals)[0]
3621+
inner_tuples = self._inner_indexer(other)[0]
36243622
sort = False # inner_tuples is already sorted
36253623
except TypeError:
36263624
pass

pandas/tests/indexes/period/test_join.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ class TestJoin:
1515
def test_join_outer_indexer(self):
1616
pi = period_range("1/1/2000", "1/20/2000", freq="D")
1717

18-
result = pi._outer_indexer(pi._values, pi._values)
18+
result = pi._outer_indexer(pi)
1919
tm.assert_extension_array_equal(result[0], pi._values)
2020
tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.intp))
2121
tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.intp))

0 commit comments

Comments
 (0)