From b2a6f29da1e1cb172cebeab16f0acf4c620991b2 Mon Sep 17 00:00:00 2001 From: ramvikrams Date: Thu, 16 Mar 2023 10:58:00 -0700 Subject: [PATCH 1/2] fixed .loc --- pandas-stubs/core/frame.pyi | 9 ++++++++- tests/test_frame.py | 8 +++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index d5f138e11..8a9bafebc 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -161,7 +161,12 @@ class _LocIndexerFrame(_LocIndexer): | Callable[[DataFrame], IndexType | MaskType | list[HashableT]] | list[HashableT] | tuple[ - IndexType | MaskType | list[HashableT] | Hashable, + IndexType + | MaskType + | list[HashableT] + | slice + | _IndexSliceTuple + | Callable, list[HashableT] | slice | Series[bool] | Callable, ], ) -> DataFrame: ... @@ -188,6 +193,8 @@ class _LocIndexerFrame(_LocIndexer): | None, ) -> Series: ... @overload + def __getitem__(self, idx: tuple[ScalarT, slice]) -> Series | DataFrame: ... + @overload def __setitem__( self, idx: MaskType | StrLike | _IndexSliceTuple | list[ScalarT], diff --git a/tests/test_frame.py b/tests/test_frame.py index 064f208bf..55def8a68 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2172,7 +2172,7 @@ def test_loc_slice() -> None: {"x": [1, 2, 3, 4]}, index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=["num", "let"]), ) - check(assert_type(df1.loc[1, :], pd.DataFrame), pd.DataFrame) + check(assert_type(df1.loc[1, :], Union[pd.Series, pd.DataFrame]), pd.DataFrame) def test_where() -> None: @@ -2521,3 +2521,9 @@ def test_align() -> None: aligned_df0, aligned_df1 = df0.align(df1) check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame) check(assert_type(aligned_df1, pd.DataFrame), pd.DataFrame) + + +def test_loc_new() -> None: + df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) + df2 = df1.loc[10, :] + check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series) From 12a4e602e458e448513b27bce9237e47e35ef993 Mon Sep 17 00:00:00 2001 From: ramvikrams Date: Fri, 17 Mar 2023 00:32:00 +0530 Subject: [PATCH 2/2] updated test name and func argument --- pandas-stubs/core/frame.pyi | 2 +- tests/test_frame.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 8a9bafebc..f8913c5cc 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -193,7 +193,7 @@ class _LocIndexerFrame(_LocIndexer): | None, ) -> Series: ... @overload - def __getitem__(self, idx: tuple[ScalarT, slice]) -> Series | DataFrame: ... + def __getitem__(self, idx: tuple[Scalar, slice]) -> Series | DataFrame: ... @overload def __setitem__( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 55def8a68..c6e57f58a 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2523,7 +2523,7 @@ def test_align() -> None: check(assert_type(aligned_df1, pd.DataFrame), pd.DataFrame) -def test_loc_new() -> None: +def test_loc_returns_series() -> None: df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) df2 = df1.loc[10, :] check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series)