-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
PERF: optimize algos.take for repeated calls #39692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4512f9c
36c3ed2
6d52932
ded773a
2ee2543
f489ba5
96305c5
480d2b4
c70ac4d
9fba887
5273cd5
d3dd4e4
288c6f2
06a3901
ca30487
bf598a7
05b6b87
2284813
4861fdb
a41ee6b
b52e1ec
76371cf
2faf70b
1c19732
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
""" | ||
from __future__ import annotations | ||
|
||
import functools | ||
import operator | ||
from textwrap import dedent | ||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, cast | ||
|
@@ -73,6 +74,9 @@ | |
_shared_docs: Dict[str, str] = {} | ||
|
||
|
||
maybe_promote_cached = functools.lru_cache(maxsize=128)(maybe_promote) | ||
|
||
|
||
# --------------- # | ||
# dtype access # | ||
# --------------- # | ||
|
@@ -1534,40 +1538,52 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info): | |
} | ||
|
||
|
||
@functools.lru_cache(maxsize=128) | ||
def __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis): | ||
""" | ||
Part of _get_take_nd_function below that doesn't need the mask | ||
and thus can be cached. | ||
""" | ||
tup = (arr_dtype.name, out_dtype.name) | ||
if ndim == 1: | ||
func = _take_1d_dict.get(tup, None) | ||
elif ndim == 2: | ||
if axis == 0: | ||
func = _take_2d_axis0_dict.get(tup, None) | ||
jbrockmendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
func = _take_2d_axis1_dict.get(tup, None) | ||
if func is not None: | ||
return func | ||
|
||
tup = (out_dtype.name, out_dtype.name) | ||
if ndim == 1: | ||
func = _take_1d_dict.get(tup, None) | ||
elif ndim == 2: | ||
if axis == 0: | ||
func = _take_2d_axis0_dict.get(tup, None) | ||
else: | ||
func = _take_2d_axis1_dict.get(tup, None) | ||
if func is not None: | ||
func = _convert_wrapper(func, out_dtype) | ||
return func | ||
|
||
return None | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is the caching not on this function? having too many levels of indirection is -1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will clarify the comment above, the |
||
def _get_take_nd_function( | ||
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None | ||
): | ||
if ndim <= 2: | ||
tup = (arr_dtype.name, out_dtype.name) | ||
if ndim == 1: | ||
func = _take_1d_dict.get(tup, None) | ||
elif ndim == 2: | ||
if axis == 0: | ||
func = _take_2d_axis0_dict.get(tup, None) | ||
else: | ||
func = _take_2d_axis1_dict.get(tup, None) | ||
if func is not None: | ||
return func | ||
|
||
tup = (out_dtype.name, out_dtype.name) | ||
if ndim == 1: | ||
func = _take_1d_dict.get(tup, None) | ||
elif ndim == 2: | ||
if axis == 0: | ||
func = _take_2d_axis0_dict.get(tup, None) | ||
else: | ||
func = _take_2d_axis1_dict.get(tup, None) | ||
if func is not None: | ||
func = _convert_wrapper(func, out_dtype) | ||
return func | ||
func = __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis) | ||
|
||
def func2(arr, indexer, out, fill_value=np.nan): | ||
indexer = ensure_int64(indexer) | ||
_take_nd_object( | ||
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info | ||
) | ||
if func is None: | ||
|
||
return func2 | ||
def func(arr, indexer, out, fill_value=np.nan): | ||
indexer = ensure_int64(indexer) | ||
_take_nd_object( | ||
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info | ||
) | ||
|
||
return func | ||
|
||
|
||
def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None): | ||
|
@@ -1661,6 +1677,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None) | |
return result | ||
|
||
|
||
def _take_preprocess_indexer_and_fill_value( | ||
arr, indexer, axis, out, fill_value, allow_fill | ||
): | ||
mask_info = None | ||
|
||
if indexer is None: | ||
indexer = np.arange(arr.shape[axis], dtype=np.int64) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
else: | ||
indexer = ensure_int64(indexer, copy=False) | ||
if not allow_fill: | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
mask_info = None, False | ||
else: | ||
# check for promotion based on types only (do this first because | ||
# it's faster than computing a mask) | ||
dtype, fill_value = maybe_promote_cached(arr.dtype, fill_value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is also copied verbatim out of |
||
if dtype != arr.dtype and (out is None or out.dtype != dtype): | ||
# check if promotion is actually required based on indexer | ||
mask = indexer == -1 | ||
needs_masking = mask.any() | ||
mask_info = mask, needs_masking | ||
if needs_masking: | ||
if out is not None and out.dtype != dtype: | ||
raise TypeError("Incompatible type for fill_value") | ||
else: | ||
# if not, then depromote, set fill_value to dummy | ||
# (it won't be used but we don't want the cython code | ||
# to crash when trying to cast it to dtype) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
|
||
return indexer, dtype, fill_value, mask_info | ||
|
||
|
||
def take_nd( | ||
arr, | ||
indexer, | ||
|
@@ -1700,8 +1750,6 @@ def take_nd( | |
subarray : array-like | ||
May be the same type as the input, or cast to an ndarray. | ||
""" | ||
mask_info = None | ||
|
||
if fill_value is lib.no_default: | ||
fill_value = na_value_for_dtype(arr.dtype, compat=False) | ||
|
||
|
@@ -1712,31 +1760,9 @@ def take_nd( | |
arr = extract_array(arr) | ||
arr = np.asarray(arr) | ||
|
||
if indexer is None: | ||
indexer = np.arange(arr.shape[axis], dtype=np.int64) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
else: | ||
indexer = ensure_int64(indexer, copy=False) | ||
if not allow_fill: | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
mask_info = None, False | ||
else: | ||
# check for promotion based on types only (do this first because | ||
# it's faster than computing a mask) | ||
dtype, fill_value = maybe_promote(arr.dtype, fill_value) | ||
if dtype != arr.dtype and (out is None or out.dtype != dtype): | ||
# check if promotion is actually required based on indexer | ||
mask = indexer == -1 | ||
needs_masking = mask.any() | ||
mask_info = mask, needs_masking | ||
if needs_masking: | ||
if out is not None and out.dtype != dtype: | ||
raise TypeError("Incompatible type for fill_value") | ||
else: | ||
# if not, then depromote, set fill_value to dummy | ||
# (it won't be used but we don't want the cython code | ||
# to crash when trying to cast it to dtype) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value( | ||
arr, indexer, axis, out, fill_value, allow_fill | ||
) | ||
|
||
flip_order = False | ||
if arr.ndim == 2 and arr.flags.f_contiguous: | ||
|
@@ -1776,6 +1802,43 @@ def take_nd( | |
take_1d = take_nd | ||
|
||
|
||
def take_1d_array( | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
arr: np.ndarray, | ||
indexer: np.ndarray, | ||
out=None, | ||
fill_value=lib.no_default, | ||
allow_fill: bool = True, | ||
): | ||
""" | ||
Specialized version for 1D arrays. Differences compared to take_nd/take_1d: | ||
|
||
- Assumes input (arr, indexer) has already been converted to numpy arrays | ||
- Only works for 1D arrays | ||
|
||
""" | ||
if fill_value is lib.no_default: | ||
fill_value = na_value_for_dtype(arr.dtype, compat=False) | ||
|
||
if isinstance(arr, ABCExtensionArray): | ||
jbrockmendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Check for EA to catch DatetimeArray, TimedeltaArray | ||
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill) | ||
|
||
indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value( | ||
arr, indexer, 0, out, fill_value, allow_fill | ||
) | ||
|
||
# at this point, it's guaranteed that dtype can hold both the arr values | ||
# and the fill_value | ||
out = np.empty(indexer.shape, dtype=dtype) | ||
|
||
func = _get_take_nd_function( | ||
arr.ndim, arr.dtype, out.dtype, axis=0, mask_info=mask_info | ||
) | ||
func(arr, indexer, out, fill_value) | ||
|
||
return out | ||
|
||
|
||
def take_2d_multi(arr, indexer, fill_value=np.nan): | ||
""" | ||
Specialized Cython take which sets NaN values in one pass. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this coment does't make any sense w/o the PR context. can you put / move a doc-string here. typing a +1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "and thus can be cached" on the next line is the essential continuation of the sentence.
The mask can be an array, and thus is not hashable and thus cannot be used as argument for a cached function.