Skip to content

Commit e100576

Browse files
Dr-Irvtwoertwein
authored andcommitted
allow callable in .loc (pandas-dev#509)
1 parent bb2905a commit e100576

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

pandas-stubs/core/frame.pyi

+11-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class _LocIndexerFrame(_LocIndexer):
157157
self,
158158
idx: IndexType
159159
| MaskType
160+
| Callable[[DataFrame], IndexType | MaskType | list[HashableT]]
160161
| list[HashableT]
161162
| tuple[
162163
IndexType | MaskType | list[HashableT] | Hashable,
@@ -167,14 +168,22 @@ class _LocIndexerFrame(_LocIndexer):
167168
def __getitem__(
168169
self,
169170
idx: tuple[
170-
int | StrLike | tuple[Scalar, ...], int | StrLike | tuple[Scalar, ...]
171+
int | StrLike | tuple[Scalar, ...] | Callable[[DataFrame], ScalarT],
172+
int | StrLike | tuple[Scalar, ...],
171173
],
172174
) -> Scalar: ...
173175
@overload
174176
def __getitem__(
175177
self,
176178
idx: ScalarT
177-
| tuple[IndexType | MaskType | _IndexSliceTuple, ScalarT | None]
179+
| Callable[[DataFrame], ScalarT]
180+
| tuple[
181+
IndexType
182+
| MaskType
183+
| _IndexSliceTuple
184+
| Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType],
185+
ScalarT | None,
186+
]
178187
| None,
179188
) -> Series: ...
180189
@overload

tests/test_frame.py

+22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TypedDict,
2323
TypeVar,
2424
Union,
25+
cast,
2526
)
2627

2728
import numpy as np
@@ -2366,6 +2367,27 @@ def test_frame_dropna_subset() -> None:
23662367
)
23672368

23682369

2370+
def test_loc_callable() -> None:
2371+
# GH 256
2372+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
2373+
2374+
def select1(df: pd.DataFrame) -> pd.Series:
2375+
return df["x"] > 2.0
2376+
2377+
check(assert_type(df.loc[select1], pd.DataFrame), pd.DataFrame)
2378+
check(assert_type(df.loc[select1, :], pd.DataFrame), pd.DataFrame)
2379+
2380+
def select2(df: pd.DataFrame) -> list[Hashable]:
2381+
return [i for i in df.index if cast(int, i) % 2 == 1]
2382+
2383+
check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series)
2384+
2385+
def select3(df: pd.DataFrame) -> int:
2386+
return 1
2387+
2388+
check(assert_type(df.loc[select3, "x"], Scalar), np.integer)
2389+
2390+
23692391
def test_npint_loc_indexer() -> None:
23702392
# GH 508
23712393

0 commit comments

Comments
 (0)