
PERF: Index.join to maintain cached attributes in more cases #57023

Merged: 13 commits, Jan 24, 2024
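This PR lets ``Index.join`` hand back one of its inputs unchanged whenever the join result is identical to it, so cached attributes (uniqueness, monotonicity flags, the lookup engine) carry over instead of being recomputed, and no new ``Index`` is allocated. A rough illustration of the idea, assuming the monotonic fast path is taken (``left``/``right`` and the sizes are arbitrary choices, not from the PR):

```python
import numpy as np
import pandas as pd

left = pd.Index(np.arange(1_000_000))
right = pd.Index(np.arange(500_000))

# Populate cached attributes on the smaller index.
right.is_unique
right.is_monotonic_increasing

joined, lidx, ridx = left.join(right, how="inner", return_indexers=True)

# The inner-join result holds exactly the values of ``right``, so ``right``
# itself can be returned and its caches (and engine) are reused.
print(joined is right)  # expected True on this path after the change
print(ridx)             # expected None: ``right`` needs no reindexing
```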
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -105,6 +105,7 @@ Performance improvements
 - Performance improvement in :meth:`DataFrame.join` when left and/or right are non-unique and ``how`` is ``"left"``, ``"right"``, or ``"inner"`` (:issue:`56817`)
 - Performance improvement in :meth:`DataFrame.join` with ``how="left"`` or ``how="right"`` and ``sort=True`` (:issue:`56919`)
 - Performance improvement in :meth:`DataFrameGroupBy.ffill`, :meth:`DataFrameGroupBy.bfill`, :meth:`SeriesGroupBy.ffill`, and :meth:`SeriesGroupBy.bfill` (:issue:`56902`)
+- Performance improvement in :meth:`Index.join` by propagating cached attributes in cases where the result matches one of the inputs (:issue:`57023`)
 - Performance improvement in :meth:`Index.take` when ``indices`` is a full range indexer from zero to length of index (:issue:`56806`)
 - Performance improvement in :meth:`MultiIndex.equals` for equal length indexes (:issue:`56990`)
 - Performance improvement in indexing operations for string dtypes (:issue:`56997`)
12 changes: 5 additions & 7 deletions pandas/core/frame.py
@@ -8012,19 +8012,17 @@ def _arith_method_with_reindex(self, right: DataFrame, op) -> DataFrame:
         left = self
 
         # GH#31623, only operate on shared columns
-        cols, lcols, rcols = left.columns.join(
-            right.columns, how="inner", level=None, return_indexers=True
+        cols, lcol_indexer, rcol_indexer = left.columns.join(
+            right.columns, how="inner", return_indexers=True
         )
 
-        new_left = left.iloc[:, lcols]
-        new_right = right.iloc[:, rcols]
+        new_left = left if lcol_indexer is None else left.iloc[:, lcol_indexer]
+        new_right = right if rcol_indexer is None else right.iloc[:, rcol_indexer]
         result = op(new_left, new_right)
 
         # Do the join on the columns instead of using left._align_for_op
         # to avoid constructing two potentially large/sparse DataFrames
-        join_columns, _, _ = left.columns.join(
-            right.columns, how="outer", level=None, return_indexers=True
-        )
+        join_columns = left.columns.join(right.columns, how="outer")
 
         if result.columns.has_duplicates:
             # Avoid reindexing with a duplicate axis.
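The ``if lcol_indexer is None`` / ``if rcol_indexer is None`` guards exist because ``Index.join(..., return_indexers=True)`` can return ``None`` for an indexer when that side already equals the join result, meaning no column take is needed. A minimal sketch of the same pattern at the user level (``df1``/``df2`` are hypothetical frames, not from the PR):

```python
import pandas as pd

df1 = pd.DataFrame({"a": [1], "b": [2], "c": [3]})
df2 = pd.DataFrame({"a": [10], "b": [20]})

cols, lidx, ridx = df1.columns.join(df2.columns, how="inner", return_indexers=True)

# A None indexer means "use every label in its current order", so the frame
# can be reused as-is instead of paying for an .iloc take.
sub1 = df1 if lidx is None else df1.iloc[:, lidx]
sub2 = df2 if ridx is None else df2.iloc[:, ridx]
print(cols.tolist(), lidx, ridx)
```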
54 changes: 34 additions & 20 deletions pandas/core/indexes/base.py
@@ -5023,7 +5023,9 @@ def _join_monotonic(
                 ridx = self._left_indexer_unique(other)
             else:
                 join_array, lidx, ridx = self._left_indexer(other)
-                join_index = self._wrap_joined_index(join_array, other, lidx, ridx, how)
+                join_index, lidx, ridx = self._wrap_join_result(
+                    join_array, other, lidx, ridx, how
+                )
         elif how == "right":
             if self.is_unique:
                 # We can perform much better than the general case
@@ -5032,40 +5034,52 @@
                 ridx = None
             else:
                 join_array, ridx, lidx = other._left_indexer(self)
-                join_index = self._wrap_joined_index(join_array, other, lidx, ridx, how)
+                join_index, lidx, ridx = self._wrap_join_result(
+                    join_array, other, lidx, ridx, how
+                )
         elif how == "inner":
             join_array, lidx, ridx = self._inner_indexer(other)
-            join_index = self._wrap_joined_index(join_array, other, lidx, ridx, how)
+            join_index, lidx, ridx = self._wrap_join_result(
+                join_array, other, lidx, ridx, how
+            )
         elif how == "outer":
             join_array, lidx, ridx = self._outer_indexer(other)
-            join_index = self._wrap_joined_index(join_array, other, lidx, ridx, how)
+            join_index, lidx, ridx = self._wrap_join_result(
+                join_array, other, lidx, ridx, how
+            )
 
         lidx = None if lidx is None else ensure_platform_int(lidx)
         ridx = None if ridx is None else ensure_platform_int(ridx)
         return join_index, lidx, ridx
 
-    def _wrap_joined_index(
+    def _wrap_join_result(
         self,
         joined: ArrayLike,
         other: Self,
-        lidx: npt.NDArray[np.intp],
-        ridx: npt.NDArray[np.intp],
+        lidx: npt.NDArray[np.intp] | None,
+        ridx: npt.NDArray[np.intp] | None,
         how: JoinHow,
-    ) -> Self:
+    ) -> tuple[Self, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
         assert other.dtype == self.dtype
-        names = other.names if how == "right" else self.names
-        if isinstance(self, ABCMultiIndex):
-            # error: Incompatible return value type (got "MultiIndex",
-            # expected "Self")
-            mask = lidx == -1
-            join_idx = self.take(lidx)
-            right = cast("MultiIndex", other.take(ridx))
-            join_index = join_idx.putmask(mask, right)._sort_levels_monotonic()
-            return join_index.set_names(names)  # type: ignore[return-value]
+
+        if lidx is not None and lib.is_range_indexer(lidx, len(self)):
+            lidx = None
+        if ridx is not None and lib.is_range_indexer(ridx, len(other)):
+            ridx = None
+
+        # return self or other if possible to maintain cached attributes
+        if lidx is None:
+            join_index = self
+        elif ridx is None:
+            join_index = other
         else:
-            return self._constructor._with_infer(
-                joined, name=names[0], dtype=self.dtype
-            )
+            join_index = self._constructor._with_infer(joined, dtype=self.dtype)
+
+        names = other.names if how == "right" else self.names
+        if join_index.names != names:
+            join_index = join_index.set_names(names)
+
+        return join_index, lidx, ridx
 
     @final
     @cache_readonly
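The new early-exit leans on the internal helper ``lib.is_range_indexer`` to detect a no-op take. A small probe of that helper, plus a pure-Python re-implementation that is only illustrative of the check being performed:

```python
import numpy as np
from pandas._libs import lib  # internal helper, used here only to illustrate

lidx = np.arange(4, dtype=np.intp)
print(lib.is_range_indexer(lidx, 4))  # True: taking [0, 1, 2, 3] from 4 rows is a no-op
print(lib.is_range_indexer(lidx, 5))  # False: the indexer must cover the full length

def is_noop_take(indexer: np.ndarray, n: int) -> bool:
    # Illustrative equivalent: a take is a no-op only when the indexer is
    # exactly 0..n-1 over the full length of the object being taken from.
    return len(indexer) == n and bool((indexer == np.arange(n)).all())
```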
16 changes: 9 additions & 7 deletions pandas/core/indexes/datetimelike.py
@@ -735,18 +735,20 @@ def _get_join_freq(self, other):
             freq = self.freq
         return freq
 
-    def _wrap_joined_index(
+    def _wrap_join_result(
         self,
         joined,
         other,
-        lidx: npt.NDArray[np.intp],
-        ridx: npt.NDArray[np.intp],
+        lidx: npt.NDArray[np.intp] | None,
+        ridx: npt.NDArray[np.intp] | None,
         how: JoinHow,
-    ):
+    ) -> tuple[Self, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
         assert other.dtype == self.dtype, (other.dtype, self.dtype)
-        result = super()._wrap_joined_index(joined, other, lidx, ridx, how)
-        result._data._freq = self._get_join_freq(other)
-        return result
+        join_index, lidx, ridx = super()._wrap_join_result(
+            joined, other, lidx, ridx, how
+        )
+        join_index._data._freq = self._get_join_freq(other)
+        return join_index, lidx, ridx
 
     def _get_engine_target(self) -> np.ndarray:
         # engine methods and libjoin methods need dt64/td64 values cast to i8
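The datetime-like override now threads the extra indexers through ``super()._wrap_join_result`` and re-attaches ``freq`` to whichever object comes back, whether that is ``self``, ``other``, or a newly built index. A quick probe (the resulting ``freq`` depends on ``_get_join_freq``, so it is printed rather than asserted; the dates are arbitrary):

```python
import pandas as pd

dti = pd.date_range("2024-01-01", periods=6, freq="D")
other = dti[:3]  # slicing a regular range keeps freq="D"

joined, lidx, ridx = dti.join(other, how="inner", return_indexers=True)
print(type(joined).__name__)  # DatetimeIndex
print(joined.freq)            # freq re-attached by the override, when available
```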
5 changes: 4 additions & 1 deletion pandas/tests/indexes/multi/test_join.py
@@ -35,7 +35,10 @@ def test_join_level(idx, other, join_type):
 
     assert join_index.equals(join_index2)
     tm.assert_numpy_array_equal(lidx, lidx2)
-    tm.assert_numpy_array_equal(ridx, ridx2)
+    if ridx is None:
+        assert ridx == ridx2
+    else:
+        tm.assert_numpy_array_equal(ridx, ridx2)
     tm.assert_numpy_array_equal(join_index2.values, exp_values)
 
 
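The added branch reflects that either indexer can legitimately be ``None`` now. A tiny standalone example of a join where the right indexer comes back as ``None`` (values chosen arbitrarily):

```python
import pandas as pd

left = pd.Index([1, 2, 3, 4])
right = pd.Index([1, 2, 3])

joined, lidx, ridx = left.join(right, how="right", return_indexers=True)
# ``joined`` carries the values of ``right`` unchanged, so ``ridx`` is None
# and indexer comparisons in tests need a None-aware branch like the one above.
print(lidx, ridx)
```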