Skip to content

Commit ac2cdf5

Browse files
jbrockmendel authored and ukarroum committed
REF: simplify Index.take, MultiIndex.take (pandas-dev#37551)
1 parent a20d436 commit ac2cdf5

File tree

2 files changed

+31
-52
lines changed

2 files changed

+31
-52
lines changed

pandas/core/indexes/base.py

+22 -30
Original file line number | Diff line number | Diff line change
@@ -740,44 +740,36 @@ def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs):
740740
if kwargs:
741741
nv.validate_take(tuple(), kwargs)
742742
indices = ensure_platform_int(indices)
743-
if self._can_hold_na:
744-
taken = self._assert_take_fillable(
745-
self._values,
746-
indices,
747-
allow_fill=allow_fill,
748-
fill_value=fill_value,
749-
na_value=self._na_value,
750-
)
751-
else:
752-
if allow_fill and fill_value is not None:
753-
cls_name = type(self).__name__
754-
raise ValueError(
755-
f"Unable to fill values because {cls_name} cannot contain NA"
756-
)
757-
taken = self._values.take(indices)
743+
allow_fill = self._maybe_disallow_fill(allow_fill, fill_value, indices)
744+
745+
# Note: we discard fill_value and use self._na_value, only relevant
746+
# in the case where allow_fill is True and fill_value is not None
747+
taken = algos.take(
748+
self._values, indices, allow_fill=allow_fill, fill_value=self._na_value
749+
)
758750
return self._shallow_copy(taken)
759751

760-
def _assert_take_fillable(
761-
self, values, indices, allow_fill=True, fill_value=None, na_value=np.nan
762-
):
752+
def _maybe_disallow_fill(self, allow_fill: bool, fill_value, indices) -> bool:
763753
"""
764-
Internal method to handle NA filling of take.
754+
We only use pandas-style take when allow_fill is True _and_
755+
fill_value is not None.
765756
"""
766-
indices = ensure_platform_int(indices)
767-
768-
# only fill if we are passing a non-None fill_value
769757
if allow_fill and fill_value is not None:
770-
if (indices < -1).any():
758+
# only fill if we are passing a non-None fill_value
759+
if self._can_hold_na:
760+
if (indices < -1).any():
761+
raise ValueError(
762+
"When allow_fill=True and fill_value is not None, "
763+
"all indices must be >= -1"
764+
)
765+
else:
766+
cls_name = type(self).__name__
771767
raise ValueError(
772-
"When allow_fill=True and fill_value is not None, "
773-
"all indices must be >= -1"
768+
f"Unable to fill values because {cls_name} cannot contain NA"
774769
)
775-
taken = algos.take(
776-
values, indices, allow_fill=allow_fill, fill_value=na_value
777-
)
778770
else:
779-
taken = algos.take(values, indices, allow_fill=False, fill_value=na_value)
780-
return taken
771+
allow_fill = False
772+
return allow_fill
781773

782774
_index_shared_docs[
783775
"repeat"

pandas/core/indexes/multi.py

+9 -22
Original file line number | Diff line number | Diff line change
@@ -2018,29 +2018,13 @@ def __getitem__(self, key):
20182018
def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs):
20192019
nv.validate_take(tuple(), kwargs)
20202020
indices = ensure_platform_int(indices)
2021-
taken = self._assert_take_fillable(
2022-
self.codes,
2023-
indices,
2024-
allow_fill=allow_fill,
2025-
fill_value=fill_value,
2026-
na_value=-1,
2027-
)
2028-
return MultiIndex(
2029-
levels=self.levels, codes=taken, names=self.names, verify_integrity=False
2030-
)
20312021

2032-
def _assert_take_fillable(
2033-
self, values, indices, allow_fill=True, fill_value=None, na_value=None
2034-
):
2035-
""" Internal method to handle NA filling of take """
20362022
# only fill if we are passing a non-None fill_value
2037-
if allow_fill and fill_value is not None:
2038-
if (indices < -1).any():
2039-
msg = (
2040-
"When allow_fill=True and fill_value is not None, "
2041-
"all indices must be >= -1"
2042-
)
2043-
raise ValueError(msg)
2023+
allow_fill = self._maybe_disallow_fill(allow_fill, fill_value, indices)
2024+
2025+
na_value = -1
2026+
2027+
if allow_fill:
20442028
taken = [lab.take(indices) for lab in self.codes]
20452029
mask = indices == -1
20462030
if mask.any():
@@ -2052,7 +2036,10 @@ def _assert_take_fillable(
20522036
taken = masked
20532037
else:
20542038
taken = [lab.take(indices) for lab in self.codes]
2055-
return taken
2039+
2040+
return MultiIndex(
2041+
levels=self.levels, codes=taken, names=self.names, verify_integrity=False
2042+
)
20562043

20572044
def append(self, other):
20582045
"""

0 commit comments

Comments
 (0)