diff --git a/pandas/core/array_algos/take.py b/pandas/core/array_algos/take.py index 447167aa8fae5..2179085d9ad9b 100644 --- a/pandas/core/array_algos/take.py +++ b/pandas/core/array_algos/take.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools from typing import Optional import numpy as np @@ -177,41 +178,60 @@ def take_2d_multi( return out +@functools.lru_cache(maxsize=128) +def _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis): + """ + Part of _get_take_nd_function below that doesn't need `mask_info` and thus + can be cached (mask_info potentially contains a numpy ndarray which is not + hashable and thus cannot be used as argument for cached function). + """ + tup = (arr_dtype.name, out_dtype.name) + if ndim == 1: + func = _take_1d_dict.get(tup, None) + elif ndim == 2: + if axis == 0: + func = _take_2d_axis0_dict.get(tup, None) + else: + func = _take_2d_axis1_dict.get(tup, None) + if func is not None: + return func + + tup = (out_dtype.name, out_dtype.name) + if ndim == 1: + func = _take_1d_dict.get(tup, None) + elif ndim == 2: + if axis == 0: + func = _take_2d_axis0_dict.get(tup, None) + else: + func = _take_2d_axis1_dict.get(tup, None) + if func is not None: + func = _convert_wrapper(func, out_dtype) + return func + + return None + + def _get_take_nd_function( - ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None + ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None ): - + """ + Get the appropriate "take" implementation for the given dimension, axis + and dtypes. + """ + func = None if ndim <= 2: - tup = (arr_dtype.name, out_dtype.name) - if ndim == 1: - func = _take_1d_dict.get(tup, None) - elif ndim == 2: - if axis == 0: - func = _take_2d_axis0_dict.get(tup, None) - else: - func = _take_2d_axis1_dict.get(tup, None) - if func is not None: - return func - - tup = (out_dtype.name, out_dtype.name) - if ndim == 1: - func = _take_1d_dict.get(tup, None) - elif ndim == 2: - if axis == 0: - func = _take_2d_axis0_dict.get(tup, None) - else: - func = _take_2d_axis1_dict.get(tup, None) - if func is not None: - func = _convert_wrapper(func, out_dtype) - return func + # for this part we don't need `mask_info` -> use the cached algo lookup + func = _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis) - def func2(arr, indexer, out, fill_value=np.nan): - indexer = ensure_int64(indexer) - _take_nd_object( - arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info - ) + if func is None: + + def func(arr, indexer, out, fill_value=np.nan): + indexer = ensure_int64(indexer) + _take_nd_object( + arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info + ) - return func2 + return func def _view_wrapper(f, arr_dtype=None, out_dtype=None, fill_wrap=None): diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index b30751fcefe81..c33e567478475 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -10,6 +10,7 @@ datetime, timedelta, ) +import functools import inspect from typing import ( TYPE_CHECKING, @@ -568,6 +569,35 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan): ValueError If fill_value is a non-scalar and dtype is not object. """ + # TODO(2.0): need to directly use the non-cached version as long as we + # possibly raise a deprecation warning for datetime dtype + if dtype.kind == "M": + return _maybe_promote(dtype, fill_value) + # for performance, we are using a cached version of the actual implementation + # of the function in _maybe_promote. However, this doesn't always work (in case + # of non-hashable arguments), so we fallback to the actual implementation if needed + try: + # error: Argument 3 to "__call__" of "_lru_cache_wrapper" has incompatible type + # "Type[Any]"; expected "Hashable" [arg-type] + return _maybe_promote_cached( + dtype, fill_value, type(fill_value) # type: ignore[arg-type] + ) + except TypeError: + # if fill_value is not hashable (required for caching) + return _maybe_promote(dtype, fill_value) + + +@functools.lru_cache(maxsize=128) +def _maybe_promote_cached(dtype, fill_value, fill_value_type): + # The cached version of _maybe_promote below + # This also use fill_value_type as (unused) argument to use this in the + # cache lookup -> to differentiate 1 and True + return _maybe_promote(dtype, fill_value) + + +def _maybe_promote(dtype: np.dtype, fill_value=np.nan): + # The actual implementation of the function, use `maybe_promote` above for + # a cached version. if not is_scalar(fill_value): # with object dtype there is nothing to promote, and the user can # pass pretty much any weird fill_value they like @@ -618,7 +648,7 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan): "dtype is deprecated. In a future version, this will be cast " "to object dtype. Pass `fill_value=Timestamp(date_obj)` instead.", FutureWarning, - stacklevel=7, + stacklevel=8, ) return dtype, fv elif isinstance(fill_value, str): diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 998f1ffcf02ee..bd64a79181c51 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -55,6 +55,7 @@ ) import pandas.core.algorithms as algos +from pandas.core.array_algos.take import take_nd from pandas.core.arrays import ( DatetimeArray, ExtensionArray, @@ -1005,7 +1006,7 @@ def unstack(self, unstacker, fill_value) -> ArrayManager: new_arrays = [] for arr in self.arrays: for i in range(unstacker.full_shape[1]): - new_arr = algos.take( + new_arr = take_nd( arr, new_indexer2D[:, i], allow_fill=True, fill_value=fill_value ) new_arrays.append(new_arr) diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py index c2668844a5f14..1c7c02ea41f2c 100644 --- a/pandas/tests/test_take.py +++ b/pandas/tests/test_take.py @@ -417,6 +417,18 @@ def test_take_axis_1(self): with pytest.raises(IndexError, match="indices are out-of-bounds"): algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0) + def test_take_non_hashable_fill_value(self): + arr = np.array([1, 2, 3]) + indexer = np.array([1, -1]) + with pytest.raises(ValueError, match="fill_value must be a scalar"): + algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + + # with object dtype it is allowed + arr = np.array([1, 2, 3], dtype=object) + result = algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + expected = np.array([2, [1]], dtype=object) + tm.assert_numpy_array_equal(result, expected) + class TestExtensionTake: # The take method found in pd.api.extensions