Skip to content

Commit 8ae8173

Browse files
authored
REF: stricter typing in Manager.take (#51478)
1 parent 7f022b3 commit 8ae8173

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

pandas/core/generic.py

+6
Original file line numberDiff line numberDiff line change
@@ -3895,6 +3895,7 @@ class max_speed
38953895

38963896
return self._take(indices, axis)
38973897

3898+
@final
38983899
def _take(
38993900
self: NDFrameT,
39003901
indices,
@@ -3915,6 +3916,11 @@ def _take(
39153916
and is_range_indexer(indices, len(self))
39163917
):
39173918
return self.copy(deep=None)
3919+
else:
3920+
# We can get here with a slice via DataFrame.__geittem__
3921+
indices = np.arange(
3922+
indices.start, indices.stop, indices.step, dtype=np.intp
3923+
)
39183924

39193925
new_data = self._mgr.take(
39203926
indices,

pandas/core/internals/array_manager.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -630,21 +630,18 @@ def _reindex_indexer(
630630

631631
def take(
632632
self: T,
633-
indexer,
633+
indexer: npt.NDArray[np.intp],
634634
axis: AxisInt = 1,
635635
verify: bool = True,
636636
convert_indices: bool = True,
637637
) -> T:
638638
"""
639639
Take items along any axis.
640640
"""
641-
axis = self._normalize_axis(axis)
641+
assert isinstance(indexer, np.ndarray), type(indexer)
642+
assert indexer.dtype == np.intp, indexer.dtype
642643

643-
indexer = (
644-
np.arange(indexer.start, indexer.stop, indexer.step, dtype="int64")
645-
if isinstance(indexer, slice)
646-
else np.asanyarray(indexer, dtype="int64")
647-
)
644+
axis = self._normalize_axis(axis)
648645

649646
if not indexer.ndim == 1:
650647
raise ValueError("indexer should be 1-dimensional")

pandas/core/internals/managers.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -907,15 +907,15 @@ def _make_na_block(
907907

908908
def take(
909909
self: T,
910-
indexer,
910+
indexer: npt.NDArray[np.intp],
911911
axis: AxisInt = 1,
912912
verify: bool = True,
913913
convert_indices: bool = True,
914914
) -> T:
915915
"""
916916
Take items along any axis.
917917
918-
indexer : np.ndarray or slice
918+
indexer : np.ndarray[np.intp]
919919
axis : int, default 1
920920
verify : bool, default True
921921
Check that all entries are between 0 and len(self) - 1, inclusive.
@@ -927,12 +927,8 @@ def take(
927927
-------
928928
BlockManager
929929
"""
930-
# We have 6 tests that get here with a slice
931-
indexer = (
932-
np.arange(indexer.start, indexer.stop, indexer.step, dtype=np.intp)
933-
if isinstance(indexer, slice)
934-
else np.asanyarray(indexer, dtype=np.intp)
935-
)
930+
assert isinstance(indexer, np.ndarray), type(indexer)
931+
assert indexer.dtype == np.intp, indexer.dtype
936932

937933
n = self.shape[axis]
938934
if convert_indices:

pandas/tests/internals/test_internals.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1003,13 +1003,15 @@ def assert_take_ok(mgr, axis, indexer):
10031003

10041004
for ax in range(mgr.ndim):
10051005
# take/fancy indexer
1006-
assert_take_ok(mgr, ax, indexer=[])
1007-
assert_take_ok(mgr, ax, indexer=[0, 0, 0])
1008-
assert_take_ok(mgr, ax, indexer=list(range(mgr.shape[ax])))
1006+
assert_take_ok(mgr, ax, indexer=np.array([], dtype=np.intp))
1007+
assert_take_ok(mgr, ax, indexer=np.array([0, 0, 0], dtype=np.intp))
1008+
assert_take_ok(
1009+
mgr, ax, indexer=np.array(list(range(mgr.shape[ax])), dtype=np.intp)
1010+
)
10091011

10101012
if mgr.shape[ax] >= 3:
1011-
assert_take_ok(mgr, ax, indexer=[0, 1, 2])
1012-
assert_take_ok(mgr, ax, indexer=[-1, -2, -3])
1013+
assert_take_ok(mgr, ax, indexer=np.array([0, 1, 2], dtype=np.intp))
1014+
assert_take_ok(mgr, ax, indexer=np.array([-1, -2, -3], dtype=np.intp))
10131015

10141016
@pytest.mark.parametrize("mgr", MANAGERS)
10151017
@pytest.mark.parametrize("fill_value", [None, np.nan, 100.0])

0 commit comments

Comments
 (0)