From 319ca5860bb5791e4052a731829ec69300a62c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Thu, 31 Aug 2023 23:51:44 -0400 Subject: [PATCH 1/4] simplify DataFrame.__getitem__ --- pandas-stubs/core/frame.pyi | 20 +++----------------- tests/test_frame.py | 7 +++++++ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index eff25f3c1..83bc1deba 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1,6 +1,5 @@ from collections.abc import ( Callable, - Generator, Hashable, Iterable, Iterator, @@ -14,7 +13,6 @@ from typing import ( Any, ClassVar, Literal, - TypeVar, overload, ) @@ -119,8 +117,6 @@ from pandas._typing import ( ValidationOptions, WriteBuffer, XMLParsers, - np_ndarray_bool, - np_ndarray_str, npt, num, ) @@ -130,7 +126,6 @@ from pandas.plotting import PlotAccessor _str = str _bool = bool -_ScalarOrTupleT = TypeVar("_ScalarOrTupleT", bound=Scalar | tuple[Hashable, ...]) class _iLocIndexerFrame(_iLocIndexer): @overload @@ -553,20 +548,11 @@ class DataFrame(NDFrame, OpsMixin): def T(self) -> DataFrame: ... def __getattr__(self, name: str) -> Series: ... @overload - def __getitem__( # type: ignore[misc] - self, - key: Series - | DataFrame - | Index - | np_ndarray_str - | np_ndarray_bool - | list[_ScalarOrTupleT] - | Generator[_ScalarOrTupleT, None, None], - ) -> DataFrame: ... + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc] @overload - def __getitem__(self, key: slice) -> DataFrame: ... + def __getitem__(self, key: Iterable | slice) -> DataFrame: ... @overload - def __getitem__(self, key: Scalar | Hashable) -> Series: ... + def __getitem__(self, key: Hashable) -> Series: ... def isetitem( self, loc: int | Sequence[int], value: Scalar | ArrayLike | list[Any] ) -> None: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 383f4ecc6..e1e1f7e27 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2811,3 +2811,10 @@ def test_groupby_fillna_inplace() -> None: def test_getitem_generator() -> None: # GH 685 check(assert_type(DF[(f"col{i+1}" for i in range(2))], pd.DataFrame), pd.DataFrame) + + +def test_getitem_dict_keys() -> None: + # GH 770 + some_columns = {"a": [1], "b": [2]} + df = pd.DataFrame.from_dict(some_columns) + check(assert_type(df[some_columns.keys()], pd.DataFrame), pd.DataFrame) From 2181f297a9373a37e7eb0478b2a144495ca7c788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Thu, 31 Aug 2023 23:58:26 -0400 Subject: [PATCH 2/4] restrict to Iterable with Hashable/Scalar elements --- pandas-stubs/core/frame.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 83bc1deba..6de500b28 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -550,7 +550,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc] @overload - def __getitem__(self, key: Iterable | slice) -> DataFrame: ... + def __getitem__(self, key: Iterable[Scalar | Hashable] | slice) -> DataFrame: ... @overload def __getitem__(self, key: Hashable) -> Series: ... def isetitem( From 596cd823195cdcd2d1a9d05197332c4daebd3f28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Sep 2023 00:09:00 -0400 Subject: [PATCH 3/4] unrelated: widen type of DataFrame.__iter__, can be any column type --- pandas-stubs/core/frame.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6de500b28..5d74bb07d 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1463,7 +1463,7 @@ class DataFrame(NDFrame, OpsMixin): Name: _str # # dunder methods - def __iter__(self) -> Iterator[float | _str]: ... + def __iter__(self) -> Iterator[Scalar | Hashable]: ... # properties @property def at(self): ... # Not sure what to do with this yet; look at source From 6e78f8c22bc9675bbfa6a7c642c6093285ba51ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Sep 2023 09:48:35 -0400 Subject: [PATCH 4/4] Scalar | Hashable -> Hashable --- pandas-stubs/core/frame.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 5d74bb07d..d2394cdf5 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -550,7 +550,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc] @overload - def __getitem__(self, key: Iterable[Scalar | Hashable] | slice) -> DataFrame: ... + def __getitem__(self, key: Iterable[Hashable] | slice) -> DataFrame: ... @overload def __getitem__(self, key: Hashable) -> Series: ... def isetitem( @@ -1463,7 +1463,7 @@ class DataFrame(NDFrame, OpsMixin): Name: _str # # dunder methods - def __iter__(self) -> Iterator[Scalar | Hashable]: ... + def __iter__(self) -> Iterator[Hashable]: ... # properties @property def at(self): ... # Not sure what to do with this yet; look at source