Skip to content

Commit d2f3264

Browse files
authored
Allow subtypes of List[Scalar] by introducing ScalarArg (#124)
* Allow subtypes of List[Scalar] by introducing ScalarArg * add check on result * black on test_frame.py * ScalarArg -> ScalarT
1 parent 1d96b4d commit d2f3264

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

pandas-stubs/_typing.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ Scalar = Union[
122122
Timestamp,
123123
Timedelta,
124124
]
125+
ScalarT = TypeVar("ScalarT", bound=Scalar)
125126
# Refine the definitions below in 3.9 to use the specialized type.
126127
np_ndarray_int8 = npt.NDArray[np.int8]
127128
np_ndarray_int16 = npt.NDArray[np.int16]

pandas-stubs/core/frame.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ from pandas._typing import (
6565
MaskType,
6666
Renamer,
6767
Scalar,
68+
ScalarT,
6869
SeriesAxisType,
6970
StrLike,
7071
T as TType,
@@ -368,7 +369,7 @@ class DataFrame(NDFrame, OpsMixin):
368369
Series[_bool],
369370
DataFrame,
370371
List[_str],
371-
List[Scalar],
372+
List[ScalarT],
372373
Index,
373374
np_ndarray_str,
374375
np_ndarray_bool,

tests/test_frame.py

+8
Original file line numberDiff line numberDiff line change
@@ -1169,3 +1169,11 @@ def test_func(h: Hashable):
11691169
test_func(pd.DataFrame()) # type: ignore[arg-type]
11701170
test_func(pd.Series([], dtype=object)) # type: ignore[arg-type]
11711171
test_func(pd.Index([])) # type: ignore[arg-type]
1172+
1173+
1174+
def test_columns_mixlist() -> None:
1175+
# GH 97
1176+
df = pd.DataFrame({"a": [1, 2, 3], 1: [3, 4, 5]})
1177+
key: List[Union[int, str]]
1178+
key = [1]
1179+
check(assert_type(df[key], pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)