Skip to content

Commit 90db462

Browse files
authored
remove types from Index.__iter__()` (#532)
* remove types from Index.__iter__()` * add some tests
1 parent 23e3ff2 commit 90db462

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

pandas-stubs/core/indexes/base.pyi

+1-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ from pandas._typing import (
3333
DtypeObj,
3434
FillnaOptions,
3535
HashableT,
36-
IndexIterScalar,
3736
IndexT,
3837
Label,
3938
Level,
@@ -223,7 +222,7 @@ class Index(IndexOpsMixin, PandasObject):
223222
def shape(self) -> tuple[int, ...]: ...
224223
# Extra methods from old stubs
225224
def __eq__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
226-
def __iter__(self) -> Iterator[IndexIterScalar | tuple[Hashable, ...]]: ...
225+
def __iter__(self) -> Iterator: ...
227226
def __ne__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
228227
def __le__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]
229228
def __ge__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]

tests/test_frame.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1813,7 +1813,6 @@ def test_frame_index_numpy() -> None:
18131813

18141814

18151815
def test_frame_stack() -> None:
1816-
18171816
multicol2 = pd.MultiIndex.from_tuples([("weight", "kg"), ("height", "m")])
18181817
df_multi_level_cols2 = pd.DataFrame(
18191818
[[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=multicol2
@@ -2403,3 +2402,13 @@ def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame:
24032402

24042403
a: npt.NDArray[np.uint64] = np.array([10, 30], dtype="uint64")
24052404
check(assert_type(get_NDArray(df, a), pd.DataFrame), pd.DataFrame)
2405+
2406+
2407+
def test_in_columns() -> None:
2408+
# GH 532 (PR)
2409+
df = pd.DataFrame(np.random.random((3, 4)), columns=["cat", "dog", "rat", "pig"])
2410+
cols = [c for c in df.columns if "at" in c]
2411+
check(assert_type(cols, list), list, str)
2412+
check(assert_type(df.loc[:, cols], pd.DataFrame), pd.DataFrame)
2413+
check(assert_type(df[cols], pd.DataFrame), pd.DataFrame)
2414+
check(assert_type(df.groupby(by=cols).sum(), pd.DataFrame), pd.DataFrame)

tests/test_indexes.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import datetime as dt
44
from typing import (
55
TYPE_CHECKING,
6-
Hashable,
7-
List,
86
Tuple,
97
Union,
108
)
@@ -29,7 +27,6 @@
2927
if TYPE_CHECKING:
3028
from pandas.core.indexes.numeric import NumericIndex
3129

32-
from pandas._typing import IndexIterScalar
3330
else:
3431
if not PD_LTE_15:
3532
from pandas import Index as NumericIndex
@@ -108,7 +105,7 @@ def test_column_sequence() -> None:
108105
df = pd.DataFrame([1, 2, 3])
109106
col_list = list(df.columns)
110107
check(
111-
assert_type(col_list, List[Union["IndexIterScalar", Tuple[Hashable, ...]]]),
108+
assert_type(col_list, list),
112109
list,
113110
int,
114111
)
@@ -684,14 +681,14 @@ def test_sorted_and_list() -> None:
684681
check(
685682
assert_type(
686683
sorted(i1),
687-
List[Union["IndexIterScalar", Tuple[Hashable, ...]]],
684+
list,
688685
),
689686
list,
690687
)
691688
check(
692689
assert_type(
693690
list(i1),
694-
List[Union["IndexIterScalar", Tuple[Hashable, ...]]],
691+
list,
695692
),
696693
list,
697694
)

0 commit comments

Comments
 (0)