@@ -1661,6 +1661,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None)
1661
1661
return result
1662
1662
1663
1663
1664
+ def _take_preprocess_indexer_and_fill_value (
1665
+ arr , indexer , axis , out , fill_value , allow_fill
1666
+ ):
1667
+ mask_info = None
1668
+
1669
+ if indexer is None :
1670
+ indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
1671
+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1672
+ else :
1673
+ indexer = ensure_int64 (indexer , copy = False )
1674
+ if not allow_fill :
1675
+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1676
+ mask_info = None , False
1677
+ else :
1678
+ # check for promotion based on types only (do this first because
1679
+ # it's faster than computing a mask)
1680
+ dtype , fill_value = maybe_promote (arr .dtype , fill_value )
1681
+ if dtype != arr .dtype and (out is None or out .dtype != dtype ):
1682
+ # check if promotion is actually required based on indexer
1683
+ mask = indexer == - 1
1684
+ needs_masking = mask .any ()
1685
+ mask_info = mask , needs_masking
1686
+ if needs_masking :
1687
+ if out is not None and out .dtype != dtype :
1688
+ raise TypeError ("Incompatible type for fill_value" )
1689
+ else :
1690
+ # if not, then depromote, set fill_value to dummy
1691
+ # (it won't be used but we don't want the cython code
1692
+ # to crash when trying to cast it to dtype)
1693
+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1694
+
1695
+ return indexer , dtype , fill_value , mask_info
1696
+
1697
+
1664
1698
def take_nd (
1665
1699
arr ,
1666
1700
indexer ,
@@ -1700,8 +1734,6 @@ def take_nd(
1700
1734
subarray : array-like
1701
1735
May be the same type as the input, or cast to an ndarray.
1702
1736
"""
1703
- mask_info = None
1704
-
1705
1737
if fill_value is lib .no_default :
1706
1738
fill_value = na_value_for_dtype (arr .dtype , compat = False )
1707
1739
@@ -1712,31 +1744,9 @@ def take_nd(
1712
1744
arr = extract_array (arr )
1713
1745
arr = np .asarray (arr )
1714
1746
1715
- if indexer is None :
1716
- indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
1717
- dtype , fill_value = arr .dtype , arr .dtype .type ()
1718
- else :
1719
- indexer = ensure_int64 (indexer , copy = False )
1720
- if not allow_fill :
1721
- dtype , fill_value = arr .dtype , arr .dtype .type ()
1722
- mask_info = None , False
1723
- else :
1724
- # check for promotion based on types only (do this first because
1725
- # it's faster than computing a mask)
1726
- dtype , fill_value = maybe_promote (arr .dtype , fill_value )
1727
- if dtype != arr .dtype and (out is None or out .dtype != dtype ):
1728
- # check if promotion is actually required based on indexer
1729
- mask = indexer == - 1
1730
- needs_masking = mask .any ()
1731
- mask_info = mask , needs_masking
1732
- if needs_masking :
1733
- if out is not None and out .dtype != dtype :
1734
- raise TypeError ("Incompatible type for fill_value" )
1735
- else :
1736
- # if not, then depromote, set fill_value to dummy
1737
- # (it won't be used but we don't want the cython code
1738
- # to crash when trying to cast it to dtype)
1739
- dtype , fill_value = arr .dtype , arr .dtype .type ()
1747
+ indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
1748
+ arr , indexer , axis , out , fill_value , allow_fill
1749
+ )
1740
1750
1741
1751
flip_order = False
1742
1752
if arr .ndim == 2 and arr .flags .f_contiguous :
0 commit comments