Skip to content

Commit 6bb1215

Browse files
authored
Simplify DataFrame.__getitem__ (#771)
* simplify DataFrame.__getitem__ * restrict to Iterable with Hashable/Scalar elements * unrelated: widen type of DataFrame.__iter__, can be any column type * Scalar | Hashable -> Hashable
1 parent ab5c643 commit 6bb1215

File tree

2 files changed

+11
-18
lines changed

2 files changed

+11
-18
lines changed

pandas-stubs/core/frame.pyi

+4-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import (
22
Callable,
3-
Generator,
43
Hashable,
54
Iterable,
65
Iterator,
@@ -14,7 +13,6 @@ from typing import (
1413
Any,
1514
ClassVar,
1615
Literal,
17-
TypeVar,
1816
overload,
1917
)
2018

@@ -119,8 +117,6 @@ from pandas._typing import (
119117
ValidationOptions,
120118
WriteBuffer,
121119
XMLParsers,
122-
np_ndarray_bool,
123-
np_ndarray_str,
124120
npt,
125121
num,
126122
)
@@ -130,7 +126,6 @@ from pandas.plotting import PlotAccessor
130126

131127
_str = str
132128
_bool = bool
133-
_ScalarOrTupleT = TypeVar("_ScalarOrTupleT", bound=Scalar | tuple[Hashable, ...])
134129

135130
class _iLocIndexerFrame(_iLocIndexer):
136131
@overload
@@ -553,20 +548,11 @@ class DataFrame(NDFrame, OpsMixin):
553548
def T(self) -> DataFrame: ...
554549
def __getattr__(self, name: str) -> Series: ...
555550
@overload
556-
def __getitem__( # type: ignore[misc]
557-
self,
558-
key: Series
559-
| DataFrame
560-
| Index
561-
| np_ndarray_str
562-
| np_ndarray_bool
563-
| list[_ScalarOrTupleT]
564-
| Generator[_ScalarOrTupleT, None, None],
565-
) -> DataFrame: ...
551+
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc]
566552
@overload
567-
def __getitem__(self, key: slice) -> DataFrame: ...
553+
def __getitem__(self, key: Iterable[Hashable] | slice) -> DataFrame: ...
568554
@overload
569-
def __getitem__(self, key: Scalar | Hashable) -> Series: ...
555+
def __getitem__(self, key: Hashable) -> Series: ...
570556
def isetitem(
571557
self, loc: int | Sequence[int], value: Scalar | ArrayLike | list[Any]
572558
) -> None: ...
@@ -1477,7 +1463,7 @@ class DataFrame(NDFrame, OpsMixin):
14771463
Name: _str
14781464
#
14791465
# dunder methods
1480-
def __iter__(self) -> Iterator[float | _str]: ...
1466+
def __iter__(self) -> Iterator[Hashable]: ...
14811467
# properties
14821468
@property
14831469
def at(self): ... # Not sure what to do with this yet; look at source

tests/test_frame.py

+7
Original file line numberDiff line numberDiff line change
@@ -2811,3 +2811,10 @@ def test_groupby_fillna_inplace() -> None:
28112811
def test_getitem_generator() -> None:
28122812
# GH 685
28132813
check(assert_type(DF[(f"col{i+1}" for i in range(2))], pd.DataFrame), pd.DataFrame)
2814+
2815+
2816+
def test_getitem_dict_keys() -> None:
2817+
# GH 770
2818+
some_columns = {"a": [1], "b": [2]}
2819+
df = pd.DataFrame.from_dict(some_columns)
2820+
check(assert_type(df[some_columns.keys()], pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)