Skip to content

Commit 7ad4033

Browse files
jbrockmendel authored and yeshsurya committed
REF: pass arguments to Index._foo_indexer correctly (pandas-dev#41024)
1 parent 060364b commit 7ad4033

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

pandas/core/indexes/base.py

+63-37
Original file line number | Diff line number | Diff line change
@@ -302,23 +302,47 @@ class Index(IndexOpsMixin, PandasObject):
302302
# for why we need to wrap these instead of making them class attributes
303303
# Moreover, cython will choose the appropriate-dtyped sub-function
304304
# given the dtypes of the passed arguments
305-
def _left_indexer_unique(self, left: np.ndarray, right: np.ndarray) -> np.ndarray:
306-
return libjoin.left_join_indexer_unique(left, right)
307305

306+
@final
307+
def _left_indexer_unique(self: _IndexT, other: _IndexT) -> np.ndarray:
308+
# -> np.ndarray[np.intp]
309+
# Caller is responsible for ensuring other.dtype == self.dtype
310+
sv = self._get_join_target()
311+
ov = other._get_join_target()
312+
return libjoin.left_join_indexer_unique(sv, ov)
313+
314+
@final
308315
def _left_indexer(
309-
self, left: np.ndarray, right: np.ndarray
310-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
311-
return libjoin.left_join_indexer(left, right)
316+
self: _IndexT, other: _IndexT
317+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
318+
# Caller is responsible for ensuring other.dtype == self.dtype
319+
sv = self._get_join_target()
320+
ov = other._get_join_target()
321+
joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov)
322+
joined = self._from_join_target(joined_ndarray)
323+
return joined, lidx, ridx
312324

325+
@final
313326
def _inner_indexer(
314-
self, left: np.ndarray, right: np.ndarray
315-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
316-
return libjoin.inner_join_indexer(left, right)
327+
self: _IndexT, other: _IndexT
328+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
329+
# Caller is responsible for ensuring other.dtype == self.dtype
330+
sv = self._get_join_target()
331+
ov = other._get_join_target()
332+
joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov)
333+
joined = self._from_join_target(joined_ndarray)
334+
return joined, lidx, ridx
317335

336+
@final
318337
def _outer_indexer(
319-
self, left: np.ndarray, right: np.ndarray
320-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
321-
return libjoin.outer_join_indexer(left, right)
338+
self: _IndexT, other: _IndexT
339+
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
340+
# Caller is responsible for ensuring other.dtype == self.dtype
341+
sv = self._get_join_target()
342+
ov = other._get_join_target()
343+
joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov)
344+
joined = self._from_join_target(joined_ndarray)
345+
return joined, lidx, ridx
322346

323347
_typ = "index"
324348
_data: ExtensionArray | np.ndarray
@@ -2965,11 +2989,7 @@ def _union(self, other: Index, sort):
29652989
):
29662990
# Both are unique and monotonic, so can use outer join
29672991
try:
2968-
# error: Argument 1 to "_outer_indexer" of "Index" has incompatible type
2969-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
2970-
# error: Argument 2 to "_outer_indexer" of "Index" has incompatible type
2971-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
2972-
return self._outer_indexer(lvals, rvals)[0] # type: ignore[arg-type]
2992+
return self._outer_indexer(other)[0]
29732993
except (TypeError, IncompatibleFrequency):
29742994
# incomparable objects
29752995
value_list = list(lvals)
@@ -3083,13 +3103,10 @@ def _intersection(self, other: Index, sort=False):
30833103
"""
30843104
# TODO(EA): setops-refactor, clean all this up
30853105
lvals = self._values
3086-
rvals = other._values
30873106

30883107
if self.is_monotonic and other.is_monotonic:
30893108
try:
3090-
# error: Argument 1 to "_inner_indexer" of "Index" has incompatible type
3091-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
3092-
result = self._inner_indexer(lvals, rvals)[0] # type: ignore[arg-type]
3109+
result = self._inner_indexer(other)[0]
30933110
except TypeError:
30943111
pass
30953112
else:
@@ -4088,8 +4105,8 @@ def _join_non_unique(self, other, how="left"):
40884105
# We only get here if dtypes match
40894106
assert self.dtype == other.dtype
40904107

4091-
lvalues = self._get_engine_target()
4092-
rvalues = other._get_engine_target()
4108+
lvalues = self._get_join_target()
4109+
rvalues = other._get_join_target()
40934110

40944111
left_idx, right_idx = get_join_indexers(
40954112
[lvalues], [rvalues], how=how, sort=True
@@ -4102,7 +4119,8 @@ def _join_non_unique(self, other, how="left"):
41024119
mask = left_idx == -1
41034120
np.putmask(join_array, mask, rvalues.take(right_idx))
41044121

4105-
join_index = self._wrap_joined_index(join_array, other)
4122+
join_arraylike = self._from_join_target(join_array)
4123+
join_index = self._wrap_joined_index(join_arraylike, other)
41064124

41074125
return join_index, left_idx, right_idx
41084126

@@ -4260,9 +4278,6 @@ def _join_monotonic(self, other: Index, how="left"):
42604278
ret_index = other if how == "right" else self
42614279
return ret_index, None, None
42624280

4263-
sv = self._get_engine_target()
4264-
ov = other._get_engine_target()
4265-
42664281
ridx: np.ndarray | None
42674282
lidx: np.ndarray | None
42684283

@@ -4271,36 +4286,34 @@ def _join_monotonic(self, other: Index, how="left"):
42714286
if how == "left":
42724287
join_index = self
42734288
lidx = None
4274-
ridx = self._left_indexer_unique(sv, ov)
4289+
ridx = self._left_indexer_unique(other)
42754290
elif how == "right":
42764291
join_index = other
4277-
lidx = self._left_indexer_unique(ov, sv)
4292+
lidx = other._left_indexer_unique(self)
42784293
ridx = None
42794294
elif how == "inner":
4280-
join_array, lidx, ridx = self._inner_indexer(sv, ov)
4295+
join_array, lidx, ridx = self._inner_indexer(other)
42814296
join_index = self._wrap_joined_index(join_array, other)
42824297
elif how == "outer":
4283-
join_array, lidx, ridx = self._outer_indexer(sv, ov)
4298+
join_array, lidx, ridx = self._outer_indexer(other)
42844299
join_index = self._wrap_joined_index(join_array, other)
42854300
else:
42864301
if how == "left":
4287-
join_array, lidx, ridx = self._left_indexer(sv, ov)
4302+
join_array, lidx, ridx = self._left_indexer(other)
42884303
elif how == "right":
4289-
join_array, ridx, lidx = self._left_indexer(ov, sv)
4304+
join_array, ridx, lidx = other._left_indexer(self)
42904305
elif how == "inner":
4291-
join_array, lidx, ridx = self._inner_indexer(sv, ov)
4306+
join_array, lidx, ridx = self._inner_indexer(other)
42924307
elif how == "outer":
4293-
join_array, lidx, ridx = self._outer_indexer(sv, ov)
4308+
join_array, lidx, ridx = self._outer_indexer(other)
42944309

42954310
join_index = self._wrap_joined_index(join_array, other)
42964311

42974312
lidx = None if lidx is None else ensure_platform_int(lidx)
42984313
ridx = None if ridx is None else ensure_platform_int(ridx)
42994314
return join_index, lidx, ridx
43004315

4301-
def _wrap_joined_index(
4302-
self: _IndexT, joined: np.ndarray, other: _IndexT
4303-
) -> _IndexT:
4316+
def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT:
43044317
assert other.dtype == self.dtype
43054318

43064319
if isinstance(self, ABCMultiIndex):
@@ -4378,6 +4391,19 @@ def _get_engine_target(self) -> np.ndarray:
43784391
# ndarray]", expected "ndarray")
43794392
return self._values # type: ignore[return-value]
43804393

4394+
def _get_join_target(self) -> np.ndarray:
4395+
"""
4396+
Get the ndarray that we will pass to libjoin functions.
4397+
"""
4398+
return self._get_engine_target()
4399+
4400+
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
4401+
"""
4402+
Cast the ndarray returned from one of the libjoin.foo_indexer functions
4403+
back to type(self)._data.
4404+
"""
4405+
return result
4406+
43814407
@doc(IndexOpsMixin._memory_usage)
43824408
def memory_usage(self, deep: bool = False) -> int:
43834409
result = self._memory_usage(deep=deep)

0 commit comments

Comments
 (0)