Skip to content

Commit a3cabb3

Browse files
authored
Index slice (#311)
* Fix IndexSlice * Move IndexSliceTuple over to indexing.pyi * fix for python 3.8 * add tests. Move IndexSlice to typing so flake8 is happy * fix test to surround tuple with quotes
1 parent 7c38e4d commit a3cabb3

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

pandas-stubs/core/indexing.pyi

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import (
2-
Generic,
32
TypeVar,
43
Union,
54
)
@@ -9,14 +8,19 @@ from pandas.core.indexes.api import Index
98

109
from pandas._libs.indexing import _NDFrameIndexerBase
1110
from pandas._typing import (
11+
MaskType,
1212
Scalar,
13-
StrLike,
13+
ScalarT,
1414
)
1515

16-
_IndexSliceT = TypeVar("_IndexSliceT", bound=Union[StrLike, Scalar, slice])
16+
_IndexSliceTuple = Union[
17+
slice, tuple[Union[Index, MaskType, Scalar, list[ScalarT], slice], ...]
18+
]
1719

18-
class _IndexSlice(Generic[_IndexSliceT]):
19-
def __getitem__(self, arg) -> tuple[_IndexSliceT, ...]: ...
20+
_IndexSliceTupleT = TypeVar("_IndexSliceTupleT", bound=_IndexSliceTuple)
21+
22+
class _IndexSlice:
23+
def __getitem__(self, arg: _IndexSliceTupleT) -> _IndexSliceTupleT: ...
2024

2125
IndexSlice: _IndexSlice
2226

tests/test_frame.py

+22
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,28 @@ def test_indexslice_setitem():
12781278
df.loc[pd.IndexSlice[2, :], "z"] = [200, 300]
12791279

12801280

1281+
def test_indexslice_getitem():
1282+
# GH 300
1283+
df = (
1284+
pd.DataFrame({"x": [1, 2, 2, 3, 4], "y": [10, 20, 30, 40, 10]})
1285+
.assign(z=lambda df: df.x * df.y)
1286+
.set_index(["x", "y"])
1287+
)
1288+
ind = pd.Index([2, 3])
1289+
check(assert_type(pd.IndexSlice[ind, :], "tuple[pd.Index, slice]"), tuple)
1290+
check(assert_type(df.loc[pd.IndexSlice[ind, :]], pd.DataFrame), pd.DataFrame)
1291+
check(assert_type(df.loc[pd.IndexSlice[1:2]], pd.DataFrame), pd.DataFrame)
1292+
check(
1293+
assert_type(df.loc[pd.IndexSlice[:, df["z"] > 40], :], pd.DataFrame),
1294+
pd.DataFrame,
1295+
)
1296+
check(assert_type(df.loc[pd.IndexSlice[2, 30], "z"], Scalar), np.int64)
1297+
check(
1298+
assert_type(df.loc[pd.IndexSlice[[2, 4], [20, 40]], :], pd.DataFrame),
1299+
pd.DataFrame,
1300+
)
1301+
1302+
12811303
def test_compute_values():
12821304
df = pd.DataFrame({"x": [1, 2, 3, 4]})
12831305
s: pd.Series = pd.Series([10, 20, 30, 40])

0 commit comments

Comments
 (0)