Skip to content

Commit 2b490df

Browse files
REF: separate out indexer/mask preprocessing code in algorithms.take_nd (#39728)
1 parent 879d2fb commit 2b490df

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

pandas/core/algorithms.py

+37-27
Original file line numberDiff line numberDiff line change
@@ -1661,6 +1661,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None)
16611661
return result
16621662

16631663

1664+
def _take_preprocess_indexer_and_fill_value(
    arr: np.ndarray, indexer, axis: int, out, fill_value, allow_fill: bool
):
    """
    Normalize the indexer and choose the result dtype / fill value for a take.

    Parameters
    ----------
    arr : np.ndarray
        Array that values are taken from. Only ``arr.dtype`` and
        ``arr.shape[axis]`` are read here.
    indexer : array-like of integers or None
        Positions to take. ``None`` is replaced by
        ``np.arange(arr.shape[axis], dtype=np.int64)``; anything else is
        passed through ``ensure_int64(indexer, copy=False)``.
    axis : int
        Axis of ``arr`` used to size the default indexer.
    out : object with a ``dtype`` attribute, or None
        Optional destination for the take. Only ``out.dtype`` is inspected.
    fill_value : scalar
        Requested fill value; it is only consulted when ``allow_fill`` is
        True and ``indexer`` is not None.
    allow_fill : bool
        Whether ``-1`` entries in ``indexer`` are to be filled with
        ``fill_value``.

    Returns
    -------
    indexer : np.ndarray
        The normalized indexer.
    dtype : np.dtype
        Either ``arr.dtype`` or the dtype returned by ``maybe_promote``.
    fill_value : scalar
        Either the value returned by ``maybe_promote`` or the dummy
        ``arr.dtype.type()`` when no filling will take place.
    mask_info : tuple or None
        ``None`` when no mask was computed and nothing is known about it,
        ``(None, False)`` when ``allow_fill`` is False, otherwise
        ``(mask, needs_masking)`` with ``mask = indexer == -1`` and
        ``needs_masking = mask.any()``.

    Raises
    ------
    TypeError
        If the indexer contains ``-1``, the fill value forces a dtype other
        than ``arr.dtype``, and ``out`` was supplied with a different dtype.
    """
    if indexer is None:
        # Default indexer: every position along ``axis``. It is built here
        # and contains no -1, so the original dtype is kept and the fill
        # value is just a dummy of that dtype.
        indexer = np.arange(arr.shape[axis], dtype=np.int64)
        return indexer, arr.dtype, arr.dtype.type(), None

    indexer = ensure_int64(indexer, copy=False)

    if not allow_fill:
        # Filling is disabled: keep the dtype and record up front that no
        # masking is needed (second tuple element is ``needs_masking``).
        return indexer, arr.dtype, arr.dtype.type(), (None, False)

    # Check for promotion based on types only (do this first because it is
    # faster than computing a mask).
    dtype, fill_value = maybe_promote(arr.dtype, fill_value)
    mask_info = None

    if dtype != arr.dtype and (out is None or out.dtype != dtype):
        # Check if promotion is actually required based on the indexer.
        mask = indexer == -1
        needs_masking = mask.any()
        mask_info = mask, needs_masking

        if needs_masking:
            if out is not None and out.dtype != dtype:
                raise TypeError("Incompatible type for fill_value")
        else:
            # Nothing to fill, so depromote and set fill_value to a dummy
            # (it won't be used, but we don't want the cython code to crash
            # when trying to cast it to dtype).
            dtype, fill_value = arr.dtype, arr.dtype.type()

    return indexer, dtype, fill_value, mask_info
1696+
1697+
16641698
def take_nd(
16651699
arr,
16661700
indexer,
@@ -1700,8 +1734,6 @@ def take_nd(
17001734
subarray : array-like
17011735
May be the same type as the input, or cast to an ndarray.
17021736
"""
1703-
mask_info = None
1704-
17051737
if fill_value is lib.no_default:
17061738
fill_value = na_value_for_dtype(arr.dtype, compat=False)
17071739

@@ -1712,31 +1744,9 @@ def take_nd(
17121744
arr = extract_array(arr)
17131745
arr = np.asarray(arr)
17141746

1715-
if indexer is None:
1716-
indexer = np.arange(arr.shape[axis], dtype=np.int64)
1717-
dtype, fill_value = arr.dtype, arr.dtype.type()
1718-
else:
1719-
indexer = ensure_int64(indexer, copy=False)
1720-
if not allow_fill:
1721-
dtype, fill_value = arr.dtype, arr.dtype.type()
1722-
mask_info = None, False
1723-
else:
1724-
# check for promotion based on types only (do this first because
1725-
# it's faster than computing a mask)
1726-
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
1727-
if dtype != arr.dtype and (out is None or out.dtype != dtype):
1728-
# check if promotion is actually required based on indexer
1729-
mask = indexer == -1
1730-
needs_masking = mask.any()
1731-
mask_info = mask, needs_masking
1732-
if needs_masking:
1733-
if out is not None and out.dtype != dtype:
1734-
raise TypeError("Incompatible type for fill_value")
1735-
else:
1736-
# if not, then depromote, set fill_value to dummy
1737-
# (it won't be used but we don't want the cython code
1738-
# to crash when trying to cast it to dtype)
1739-
dtype, fill_value = arr.dtype, arr.dtype.type()
1747+
indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
1748+
arr, indexer, axis, out, fill_value, allow_fill
1749+
)
17401750

17411751
flip_order = False
17421752
if arr.ndim == 2 and arr.flags.f_contiguous:

0 commit comments

Comments
 (0)