From 1dd248bb8fafb47f8ff46ac11a1b5cf4401e1631 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Fri, 10 Feb 2023 14:05:16 -0500 Subject: [PATCH 1/2] remove types from Index.__iter__()` --- pandas-stubs/core/indexes/base.pyi | 3 +-- tests/test_frame.py | 6 ++++++ tests/test_indexes.py | 9 +++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 98a6b72bd..d24c9f161 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -33,7 +33,6 @@ from pandas._typing import ( DtypeObj, FillnaOptions, HashableT, - IndexIterScalar, IndexT, Label, Level, @@ -223,7 +222,7 @@ class Index(IndexOpsMixin, PandasObject): def shape(self) -> tuple[int, ...]: ... # Extra methods from old stubs def __eq__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override] - def __iter__(self) -> Iterator[IndexIterScalar | tuple[Hashable, ...]]: ... + def __iter__(self) -> Iterator: ... def __ne__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override] def __le__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override] def __ge__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override] diff --git a/tests/test_frame.py b/tests/test_frame.py index 2346492c2..fd5297169 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2403,3 +2403,9 @@ def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame: a: npt.NDArray[np.uint64] = np.array([10, 30], dtype="uint64") check(assert_type(get_NDArray(df, a), pd.DataFrame), pd.DataFrame) + + +def test_in_columns() -> None: + df = pd.DataFrame(np.random.random((3, 4)), columns=["cat", "dog", "rat", "pig"]) + cols = [c for c in df.columns if "at" in c] + check(assert_type(cols, list), list, str) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 4f700cd1c..c4c7c48e6 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -3,8 +3,6 @@ import datetime as dt from typing import ( TYPE_CHECKING, - Hashable, - List, Tuple, Union, ) @@ -29,7 +27,6 @@ if TYPE_CHECKING: from pandas.core.indexes.numeric import NumericIndex - from pandas._typing import IndexIterScalar else: if not PD_LTE_15: from pandas import Index as NumericIndex @@ -108,7 +105,7 @@ def test_column_sequence() -> None: df = pd.DataFrame([1, 2, 3]) col_list = list(df.columns) check( - assert_type(col_list, List[Union["IndexIterScalar", Tuple[Hashable, ...]]]), + assert_type(col_list, list), list, int, ) @@ -684,14 +681,14 @@ def test_sorted_and_list() -> None: check( assert_type( sorted(i1), - List[Union["IndexIterScalar", Tuple[Hashable, ...]]], + list, ), list, ) check( assert_type( list(i1), - List[Union["IndexIterScalar", Tuple[Hashable, ...]]], + list, ), list, ) From 50f4f9f1d2072675318db09ac7b2fcc2e71953fe Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Fri, 10 Feb 2023 15:51:14 -0500 Subject: [PATCH 2/2] add some tests --- tests/test_frame.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_frame.py b/tests/test_frame.py index fd5297169..6f12eb2df 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1813,7 +1813,6 @@ def test_frame_index_numpy() -> None: def test_frame_stack() -> None: - multicol2 = pd.MultiIndex.from_tuples([("weight", "kg"), ("height", "m")]) df_multi_level_cols2 = pd.DataFrame( [[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=multicol2 @@ -2406,6 +2405,10 @@ def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame: def test_in_columns() -> None: + # GH 532 (PR) df = pd.DataFrame(np.random.random((3, 4)), columns=["cat", "dog", "rat", "pig"]) cols = [c for c in df.columns if "at" in c] check(assert_type(cols, list), list, str) + check(assert_type(df.loc[:, cols], pd.DataFrame), pd.DataFrame) + check(assert_type(df[cols], pd.DataFrame), pd.DataFrame) + check(assert_type(df.groupby(by=cols).sum(), pd.DataFrame), pd.DataFrame)