@@ -67,6 +67,9 @@ def take_nd(
67
67
This dispatches to ``take`` defined on ExtensionArrays. It does not
68
68
currently dispatch to ``SparseArray.take`` for sparse ``arr``.
69
69
70
+ Note: this function assumes that the indexer is a valid(ated) indexer with
71
+ no out of bound indices.
72
+
70
73
Parameters
71
74
----------
72
75
arr : np.ndarray or ExtensionArray
@@ -113,8 +116,13 @@ def _take_nd_ndarray(
113
116
allow_fill : bool ,
114
117
) -> np .ndarray :
115
118
119
+ if indexer is None :
120
+ indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
121
+ dtype , fill_value = arr .dtype , arr .dtype .type ()
122
+ else :
123
+ indexer = ensure_int64 (indexer , copy = False )
116
124
indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
117
- arr , indexer , axis , out , fill_value , allow_fill
125
+ arr , indexer , out , fill_value , allow_fill
118
126
)
119
127
120
128
flip_order = False
@@ -159,13 +167,17 @@ def take_1d(
159
167
allow_fill : bool = True ,
160
168
) -> ArrayLike :
161
169
"""
162
- Specialized version for 1D arrays. Differences compared to take_nd:
170
+ Specialized version for 1D arrays. Differences compared to `take_nd`:
163
171
164
- - Assumes input (arr, indexer) has already been converted to numpy array / EA
172
+ - Assumes input array has already been converted to numpy array / EA
173
+ - Assumes indexer is already guaranteed to be int64 dtype ndarray
165
174
- Only works for 1D arrays
166
175
167
176
To ensure the lowest possible overhead.
168
177
178
+ Note: similarly to `take_nd`, this function assumes that the indexer is
179
+ a valid(ated) indexer with no out of bound indices.
180
+
169
181
TODO(ArrayManager): mainly useful for ArrayManager, otherwise can potentially
170
182
be removed again if we don't end up with ArrayManager.
171
183
"""
@@ -180,8 +192,11 @@ def take_1d(
180
192
allow_fill = allow_fill ,
181
193
)
182
194
195
+ if not allow_fill :
196
+ return arr .take (indexer )
197
+
183
198
indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
184
- arr , indexer , 0 , None , fill_value , allow_fill
199
+ arr , indexer , None , fill_value , allow_fill
185
200
)
186
201
187
202
# at this point, it's guaranteed that dtype can hold both the arr values
@@ -502,43 +517,32 @@ def _take_2d_multi_object(
502
517
503
518
def _take_preprocess_indexer_and_fill_value (
504
519
arr : np .ndarray ,
505
- indexer : Optional [np .ndarray ],
506
- axis : int ,
520
+ indexer : np .ndarray ,
507
521
out : Optional [np .ndarray ],
508
522
fill_value ,
509
523
allow_fill : bool ,
510
524
):
511
525
mask_info = None
512
526
513
- if indexer is None :
514
- indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
527
+ if not allow_fill :
515
528
dtype , fill_value = arr .dtype , arr .dtype .type ()
529
+ mask_info = None , False
516
530
else :
517
- indexer = ensure_int64 (indexer , copy = False )
518
- if not allow_fill :
519
- dtype , fill_value = arr .dtype , arr .dtype .type ()
520
- mask_info = None , False
521
- else :
522
- # check for promotion based on types only (do this first because
523
- # it's faster than computing a mask)
524
- dtype , fill_value = maybe_promote (arr .dtype , fill_value )
525
- if dtype != arr .dtype and (out is None or out .dtype != dtype ):
526
- # check if promotion is actually required based on indexer
527
- mask = indexer == - 1
528
- # error: Item "bool" of "Union[Any, bool]" has no attribute "any"
529
- # [union-attr]
530
- needs_masking = mask .any () # type: ignore[union-attr]
531
- # error: Incompatible types in assignment (expression has type
532
- # "Tuple[Union[Any, bool], Any]", variable has type
533
- # "Optional[Tuple[None, bool]]")
534
- mask_info = mask , needs_masking # type: ignore[assignment]
535
- if needs_masking :
536
- if out is not None and out .dtype != dtype :
537
- raise TypeError ("Incompatible type for fill_value" )
538
- else :
539
- # if not, then depromote, set fill_value to dummy
540
- # (it won't be used but we don't want the cython code
541
- # to crash when trying to cast it to dtype)
542
- dtype , fill_value = arr .dtype , arr .dtype .type ()
531
+ # check for promotion based on types only (do this first because
532
+ # it's faster than computing a mask)
533
+ dtype , fill_value = maybe_promote (arr .dtype , fill_value )
534
+ if dtype != arr .dtype and (out is None or out .dtype != dtype ):
535
+ # check if promotion is actually required based on indexer
536
+ mask = indexer == - 1
537
+ needs_masking = mask .any ()
538
+ mask_info = mask , needs_masking
539
+ if needs_masking :
540
+ if out is not None and out .dtype != dtype :
541
+ raise TypeError ("Incompatible type for fill_value" )
542
+ else :
543
+ # if not, then depromote, set fill_value to dummy
544
+ # (it won't be used but we don't want the cython code
545
+ # to crash when trying to cast it to dtype)
546
+ dtype , fill_value = arr .dtype , arr .dtype .type ()
543
547
544
548
return indexer , dtype , fill_value , mask_info
0 commit comments