From 0cee941e3ec0754500e41cfb67c6cff1c80f7071 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 9 Jul 2022 22:13:44 -0400 Subject: [PATCH 1/4] Allow subtypes of List[Scalar] by introducing ScalarArg --- pandas-stubs/_typing.pyi | 1 + pandas-stubs/core/frame.pyi | 3 ++- tests/test_frame.py | 7 +++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 8eafcda24..ca178efad 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -122,6 +122,7 @@ Scalar = Union[ Timestamp, Timedelta, ] +ScalarArg = TypeVar("ScalarArg", bound=Scalar) # Refine the definitions below in 3.9 to use the specialized type. np_ndarray_int8 = npt.NDArray[np.int8] np_ndarray_int16 = npt.NDArray[np.int16] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 39f3f2384..f8169c449 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -65,6 +65,7 @@ from pandas._typing import ( MaskType, Renamer, Scalar, + ScalarArg, SeriesAxisType, StrLike, T as TType, @@ -368,7 +369,7 @@ class DataFrame(NDFrame, OpsMixin): Series[_bool], DataFrame, List[_str], - List[Scalar], + List[ScalarArg], Index, np_ndarray_str, np_ndarray_bool, diff --git a/tests/test_frame.py b/tests/test_frame.py index a50950778..855f4e73a 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1169,3 +1169,10 @@ def test_func(h: Hashable): test_func(pd.DataFrame()) # type: ignore[arg-type] test_func(pd.Series([], dtype=object)) # type: ignore[arg-type] test_func(pd.Index([])) # type: ignore[arg-type] + +def test_columns_mixlist() -> None: + # GH 97 + df = pd.DataFrame({"a":[1,2,3],1:[3,4,5]}) + key: List[Union[int, str]] + key = [1] + df[key] \ No newline at end of file From 76ffbddbbfaf1fefa06834cc26c30f3cfe602d3e Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 9 Jul 2022 22:15:34 -0400 Subject: [PATCH 2/4] add check on result --- tests/test_frame.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_frame.py b/tests/test_frame.py index 855f4e73a..45753d434 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1175,4 +1175,5 @@ def test_columns_mixlist() -> None: df = pd.DataFrame({"a":[1,2,3],1:[3,4,5]}) key: List[Union[int, str]] key = [1] - df[key] \ No newline at end of file + check(assert_type(df[key], pd.DataFrame), pd.DataFrame) + \ No newline at end of file From 46e7637c85eef419bdf847e54463255ec41661e4 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 9 Jul 2022 22:21:07 -0400 Subject: [PATCH 3/4] black on test_frame.py --- tests/test_frame.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_frame.py b/tests/test_frame.py index 45753d434..38760e986 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1170,10 +1170,10 @@ def test_func(h: Hashable): test_func(pd.Series([], dtype=object)) # type: ignore[arg-type] test_func(pd.Index([])) # type: ignore[arg-type] + def test_columns_mixlist() -> None: # GH 97 - df = pd.DataFrame({"a":[1,2,3],1:[3,4,5]}) + df = pd.DataFrame({"a": [1, 2, 3], 1: [3, 4, 5]}) key: List[Union[int, str]] key = [1] check(assert_type(df[key], pd.DataFrame), pd.DataFrame) - \ No newline at end of file From 5dfdcf1fdf85ea2cf6e6be82fa2aff612e2f4960 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 10 Jul 2022 12:38:58 -0400 Subject: [PATCH 4/4] ScalarArg -> ScalarT --- pandas-stubs/_typing.pyi | 2 +- pandas-stubs/core/frame.pyi | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index ca178efad..0fc1d674f 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -122,7 +122,7 @@ Scalar = Union[ Timestamp, Timedelta, ] -ScalarArg = TypeVar("ScalarArg", bound=Scalar) +ScalarT = TypeVar("ScalarT", bound=Scalar) # Refine the definitions below in 3.9 to use the specialized type. np_ndarray_int8 = npt.NDArray[np.int8] np_ndarray_int16 = npt.NDArray[np.int16] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index f8169c449..610f3f420 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -65,7 +65,7 @@ from pandas._typing import ( MaskType, Renamer, Scalar, - ScalarArg, + ScalarT, SeriesAxisType, StrLike, T as TType, @@ -369,7 +369,7 @@ class DataFrame(NDFrame, OpsMixin): Series[_bool], DataFrame, List[_str], - List[ScalarArg], + List[ScalarT], Index, np_ndarray_str, np_ndarray_bool,