Skip to content

PERF: optimize algos.take for repeated calls #39692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4512f9c
PERF: optimize algos.take for repeated calls
jorisvandenbossche Feb 9, 2021
36c3ed2
fix nd check + fix cache differentiation of int / bool
jorisvandenbossche Feb 9, 2021
6d52932
fix non-scalar fill_value case
jorisvandenbossche Feb 9, 2021
ded773a
fix mypy
jorisvandenbossche Feb 9, 2021
2ee2543
try fix mypy
jorisvandenbossche Feb 9, 2021
f489ba5
fix annotation
jorisvandenbossche Feb 10, 2021
96305c5
improve docstrings
jorisvandenbossche Feb 10, 2021
480d2b4
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 10, 2021
c70ac4d
faster EA check
jorisvandenbossche Feb 10, 2021
9fba887
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 11, 2021
5273cd5
rename take_1d_array to take_1d
jorisvandenbossche Feb 11, 2021
d3dd4e4
add comment about being useful for array manager
jorisvandenbossche Feb 11, 2021
288c6f2
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 15, 2021
06a3901
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
ca30487
use take_nd for now
jorisvandenbossche Mar 2, 2021
bf598a7
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
05b6b87
move caching of maybe_promote to cast.py
jorisvandenbossche Mar 2, 2021
2284813
move type comment
jorisvandenbossche Mar 2, 2021
4861fdb
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
a41ee6b
typo
jorisvandenbossche Mar 2, 2021
b52e1ec
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 4, 2021
76371cf
ensure deprecation warning is always raised
jorisvandenbossche Mar 4, 2021
2faf70b
single underscore
jorisvandenbossche Mar 4, 2021
1c19732
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions pandas/core/array_algos/take.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -177,41 +178,60 @@ def take_2d_multi(
return out


@functools.lru_cache(maxsize=128)
def _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis):
    """
    Part of _get_take_nd_function below that doesn't need `mask_info` and thus
    can be cached (mask_info potentially contains a numpy ndarray which is not
    hashable and thus cannot be used as argument for cached function).

    Parameters
    ----------
    ndim : int
        Dimensionality of the array to take from. Specialised functions are
        only looked up for 1 and 2.
    arr_dtype, out_dtype : objects with a ``name`` attribute
        Source and result dtypes; their names form the lookup key.
    axis : int
        For ``ndim == 2``, 0 selects the axis-0 table and any other value
        selects the axis-1 table. Ignored for ``ndim == 1``.

    Returns
    -------
    callable or None
        The function registered for ``(arr_dtype.name, out_dtype.name)``;
        otherwise the function registered for ``(out_dtype.name,
        out_dtype.name)`` wrapped with ``_convert_wrapper``; otherwise None.
    """
    # Select the lookup table once; both lookups below use the same table.
    if ndim == 1:
        table = _take_1d_dict
    elif ndim == 2:
        table = _take_2d_axis0_dict if axis == 0 else _take_2d_axis1_dict
    else:
        # No specialised table for this dimensionality. Returning None here
        # (instead of falling through with `func` unbound) lets the caller
        # use its generic fallback.
        return None

    # First choice: a function registered for exactly this dtype pair.
    func = table.get((arr_dtype.name, out_dtype.name), None)
    if func is not None:
        return func

    # Second choice: the out_dtype -> out_dtype function, wrapped with
    # _convert_wrapper (presumably so the input is converted to out_dtype
    # before the take; confirm against _convert_wrapper).
    func = table.get((out_dtype.name, out_dtype.name), None)
    if func is not None:
        return _convert_wrapper(func, out_dtype)

    return None


def _get_take_nd_function(
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
):

"""
Get the appropriate "take" implementation for the given dimension, axis
and dtypes.
"""
func = None
if ndim <= 2:
tup = (arr_dtype.name, out_dtype.name)
if ndim == 1:
func = _take_1d_dict.get(tup, None)
elif ndim == 2:
if axis == 0:
func = _take_2d_axis0_dict.get(tup, None)
else:
func = _take_2d_axis1_dict.get(tup, None)
if func is not None:
return func

tup = (out_dtype.name, out_dtype.name)
if ndim == 1:
func = _take_1d_dict.get(tup, None)
elif ndim == 2:
if axis == 0:
func = _take_2d_axis0_dict.get(tup, None)
else:
func = _take_2d_axis1_dict.get(tup, None)
if func is not None:
func = _convert_wrapper(func, out_dtype)
return func
# for this part we don't need `mask_info` -> use the cached algo lookup
func = _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis)

def func2(arr, indexer, out, fill_value=np.nan):
indexer = ensure_int64(indexer)
_take_nd_object(
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
)
if func is None:

def func(arr, indexer, out, fill_value=np.nan):
indexer = ensure_int64(indexer)
_take_nd_object(
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
)

return func2
return func


def _view_wrapper(f, arr_dtype=None, out_dtype=None, fill_wrap=None):
Expand Down
32 changes: 31 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
datetime,
timedelta,
)
import functools
import inspect
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -568,6 +569,35 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan):
ValueError
If fill_value is a non-scalar and dtype is not object.
"""
# TODO(2.0): need to directly use the non-cached version as long as we
# possibly raise a deprecation warning for datetime dtype
if dtype.kind == "M":
return _maybe_promote(dtype, fill_value)
Comment on lines +572 to +575
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit unfortunate, but to ensure the warning is always shown, we can't use the cached version for datetime data.

I checked which option would be fastest. The most specific check would be `if isinstance(fill_value, date) and not isinstance(fill_value, datetime)`, but `if dtype.kind == "M"` is a bit faster.
So the trade-off was between faster for all non-M8 dtypes vs faster for M8 (by being able to use the cached version in most cases) but a bit slower for all other dtypes. So I went with the first (fastest for numeric dtypes).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how big is the perf tradeoff?

since stacklevels are a constant hassle, one option would be to take the find_stacklevel function and change it so that instead of hard-coding "astype" it just looks for the first call that isn't from inside (non-test) pandas

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the stacklevel as such, it's the warning itself. With caching, it occurs only once, while otherwise this warning is raised every time you use it.

The other option would be to check for this case / raise the warning a level higher up (so e.g. the line we are commenting on), so that other cases still use the cached version.

# for performance, we are using a cached version of the actual implementation
# of the function in _maybe_promote. However, this doesn't always work (in case
# of non-hashable arguments), so we fallback to the actual implementation if needed
try:
# error: Argument 3 to "__call__" of "_lru_cache_wrapper" has incompatible type
# "Type[Any]"; expected "Hashable" [arg-type]
return _maybe_promote_cached(
dtype, fill_value, type(fill_value) # type: ignore[arg-type]
)
except TypeError:
# if fill_value is not hashable (required for caching)
return _maybe_promote(dtype, fill_value)


@functools.lru_cache(maxsize=128)
def _maybe_promote_cached(dtype, fill_value, fill_value_type):
    """
    Cached version of `_maybe_promote` below.

    `fill_value_type` is not used in the body; it is an argument only so that
    it takes part in the cache lookup, which differentiates fill values such
    as 1 and True that would otherwise share a cache entry.

    All arguments must be hashable because they form the lru_cache key.
    """
    # The result depends only on dtype and fill_value.
    return _maybe_promote(dtype, fill_value)


def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
# The actual implementation of the function, use `maybe_promote` above for
# a cached version.
if not is_scalar(fill_value):
# with object dtype there is nothing to promote, and the user can
# pass pretty much any weird fill_value they like
Expand Down Expand Up @@ -618,7 +648,7 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan):
"dtype is deprecated. In a future version, this will be cast "
"to object dtype. Pass `fill_value=Timestamp(date_obj)` instead.",
FutureWarning,
stacklevel=7,
stacklevel=8,
)
return dtype, fv
elif isinstance(fill_value, str):
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)

import pandas.core.algorithms as algos
from pandas.core.array_algos.take import take_nd
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
Expand Down Expand Up @@ -1005,7 +1006,7 @@ def unstack(self, unstacker, fill_value) -> ArrayManager:
new_arrays = []
for arr in self.arrays:
for i in range(unstacker.full_shape[1]):
new_arr = algos.take(
new_arr = take_nd(
arr, new_indexer2D[:, i], allow_fill=True, fill_value=fill_value
)
new_arrays.append(new_arr)
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/test_take.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ def test_take_axis_1(self):
with pytest.raises(IndexError, match="indices are out-of-bounds"):
algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)

def test_take_non_hashable_fill_value(self):
    """
    A list (non-hashable) fill_value raises ValueError for an integer array,
    but is accepted for an object-dtype array, where the -1 position is
    filled with the list itself.
    """
    positions = np.array([1, -1])

    # integer array: a list fill_value must be rejected
    int_values = np.array([1, 2, 3])
    with pytest.raises(ValueError, match="fill_value must be a scalar"):
        algos.take(int_values, positions, allow_fill=True, fill_value=[1])

    # object array: the same fill_value is allowed
    obj_values = np.array([1, 2, 3], dtype=object)
    taken = algos.take(obj_values, positions, allow_fill=True, fill_value=[1])
    tm.assert_numpy_array_equal(taken, np.array([2, [1]], dtype=object))


class TestExtensionTake:
# The take method found in pd.api.extensions
Expand Down