Skip to content

Commit 1f16762

Browse files
authored
TYP: Misc changes for pandas-stubs; use Protocol to avoid str in Sequence (#55263)
* TYP: misc changes for pandas-stubs test * re-write changes from 47233 with SequenceNotStr * pyupgrade
1 parent 89bd569 commit 1f16762

File tree

10 files changed

+73
-36
lines changed

10 files changed

+73
-36
lines changed

pandas/_typing.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Type as type_t,
2525
TypeVar,
2626
Union,
27+
overload,
2728
)
2829

2930
import numpy as np
@@ -85,6 +86,8 @@
8586
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
8687
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]
8788

89+
from typing import SupportsIndex
90+
8891
if sys.version_info >= (3, 10):
8992
from typing import TypeGuard # pyright: ignore[reportUnusedImport]
9093
else:
@@ -109,18 +112,48 @@
109112

110113
# list-like
111114

112-
# Cannot use `Sequence` because a string is a sequence, and we don't want to
113-
# accept that. Could refine if https://github.com/python/typing/issues/256 is
114-
# resolved to differentiate between Sequence[str] and str
115-
ListLike = Union[AnyArrayLike, list, tuple, range]
115+
# from https://github.com/hauntsaninja/useful_types
116+
# includes Sequence-like objects but excludes str and bytes
117+
_T_co = TypeVar("_T_co", covariant=True)
118+
119+
120+
class SequenceNotStr(Protocol[_T_co]):
121+
@overload
122+
def __getitem__(self, index: SupportsIndex, /) -> _T_co:
123+
...
124+
125+
@overload
126+
def __getitem__(self, index: slice, /) -> Sequence[_T_co]:
127+
...
128+
129+
def __contains__(self, value: object, /) -> bool:
130+
...
131+
132+
def __len__(self) -> int:
133+
...
134+
135+
def __iter__(self) -> Iterator[_T_co]:
136+
...
137+
138+
def index(self, value: Any, /, start: int = 0, stop: int = ...) -> int:
139+
...
140+
141+
def count(self, value: Any, /) -> int:
142+
...
143+
144+
def __reversed__(self) -> Iterator[_T_co]:
145+
...
146+
147+
148+
ListLike = Union[AnyArrayLike, SequenceNotStr, range]
116149

117150
# scalars
118151

119152
PythonScalar = Union[str, float, bool]
120153
DatetimeLikeScalar = Union["Period", "Timestamp", "Timedelta"]
121154
PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"]
122155
Scalar = Union[PythonScalar, PandasScalar, np.datetime64, np.timedelta64, date]
123-
IntStrT = TypeVar("IntStrT", int, str)
156+
IntStrT = TypeVar("IntStrT", bound=Union[int, str])
124157

125158

126159
# timestamp and timedelta convertible types

pandas/core/frame.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@
240240
Renamer,
241241
Scalar,
242242
Self,
243+
SequenceNotStr,
243244
SortKind,
244245
StorageOptions,
245246
Suffixes,
@@ -1187,7 +1188,7 @@ def to_string(
11871188
buf: None = ...,
11881189
columns: Axes | None = ...,
11891190
col_space: int | list[int] | dict[Hashable, int] | None = ...,
1190-
header: bool | list[str] = ...,
1191+
header: bool | SequenceNotStr[str] = ...,
11911192
index: bool = ...,
11921193
na_rep: str = ...,
11931194
formatters: fmt.FormattersType | None = ...,
@@ -1212,7 +1213,7 @@ def to_string(
12121213
buf: FilePath | WriteBuffer[str],
12131214
columns: Axes | None = ...,
12141215
col_space: int | list[int] | dict[Hashable, int] | None = ...,
1215-
header: bool | list[str] = ...,
1216+
header: bool | SequenceNotStr[str] = ...,
12161217
index: bool = ...,
12171218
na_rep: str = ...,
12181219
formatters: fmt.FormattersType | None = ...,
@@ -1250,7 +1251,7 @@ def to_string(
12501251
buf: FilePath | WriteBuffer[str] | None = None,
12511252
columns: Axes | None = None,
12521253
col_space: int | list[int] | dict[Hashable, int] | None = None,
1253-
header: bool | list[str] = True,
1254+
header: bool | SequenceNotStr[str] = True,
12541255
index: bool = True,
12551256
na_rep: str = "NaN",
12561257
formatters: fmt.FormattersType | None = None,
@@ -10563,9 +10564,9 @@ def merge(
1056310564
self,
1056410565
right: DataFrame | Series,
1056510566
how: MergeHow = "inner",
10566-
on: IndexLabel | None = None,
10567-
left_on: IndexLabel | None = None,
10568-
right_on: IndexLabel | None = None,
10567+
on: IndexLabel | AnyArrayLike | None = None,
10568+
left_on: IndexLabel | AnyArrayLike | None = None,
10569+
right_on: IndexLabel | AnyArrayLike | None = None,
1056910570
left_index: bool = False,
1057010571
right_index: bool = False,
1057110572
sort: bool = False,

pandas/core/generic.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
Renamer,
7373
Scalar,
7474
Self,
75+
SequenceNotStr,
7576
SortKind,
7677
StorageOptions,
7778
Suffixes,
@@ -3273,7 +3274,7 @@ def to_latex(
32733274
self,
32743275
buf: None = ...,
32753276
columns: Sequence[Hashable] | None = ...,
3276-
header: bool_t | list[str] = ...,
3277+
header: bool_t | SequenceNotStr[str] = ...,
32773278
index: bool_t = ...,
32783279
na_rep: str = ...,
32793280
formatters: FormattersType | None = ...,
@@ -3300,7 +3301,7 @@ def to_latex(
33003301
self,
33013302
buf: FilePath | WriteBuffer[str],
33023303
columns: Sequence[Hashable] | None = ...,
3303-
header: bool_t | list[str] = ...,
3304+
header: bool_t | SequenceNotStr[str] = ...,
33043305
index: bool_t = ...,
33053306
na_rep: str = ...,
33063307
formatters: FormattersType | None = ...,
@@ -3330,7 +3331,7 @@ def to_latex(
33303331
self,
33313332
buf: FilePath | WriteBuffer[str] | None = None,
33323333
columns: Sequence[Hashable] | None = None,
3333-
header: bool_t | list[str] = True,
3334+
header: bool_t | SequenceNotStr[str] = True,
33343335
index: bool_t = True,
33353336
na_rep: str = "NaN",
33363337
formatters: FormattersType | None = None,

pandas/core/methods/describe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def describe_timestamp_as_categorical_1d(
301301
names = ["count", "unique"]
302302
objcounts = data.value_counts()
303303
count_unique = len(objcounts[objcounts != 0])
304-
result = [data.count(), count_unique]
304+
result: list[float | Timestamp] = [data.count(), count_unique]
305305
dtype = None
306306
if count_unique > 0:
307307
top, freq = objcounts.index[0], objcounts.iloc[0]

pandas/core/resample.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ def count(self):
15411541

15421542
return result
15431543

1544-
def quantile(self, q: float | AnyArrayLike = 0.5, **kwargs):
1544+
def quantile(self, q: float | list[float] | AnyArrayLike = 0.5, **kwargs):
15451545
"""
15461546
Return value at the given quantile.
15471547

pandas/core/reshape/merge.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def merge(
138138
left: DataFrame | Series,
139139
right: DataFrame | Series,
140140
how: MergeHow = "inner",
141-
on: IndexLabel | None = None,
142-
left_on: IndexLabel | None = None,
143-
right_on: IndexLabel | None = None,
141+
on: IndexLabel | AnyArrayLike | None = None,
142+
left_on: IndexLabel | AnyArrayLike | None = None,
143+
right_on: IndexLabel | AnyArrayLike | None = None,
144144
left_index: bool = False,
145145
right_index: bool = False,
146146
sort: bool = False,
@@ -187,9 +187,9 @@ def merge(
187187
def _cross_merge(
188188
left: DataFrame,
189189
right: DataFrame,
190-
on: IndexLabel | None = None,
191-
left_on: IndexLabel | None = None,
192-
right_on: IndexLabel | None = None,
190+
on: IndexLabel | AnyArrayLike | None = None,
191+
left_on: IndexLabel | AnyArrayLike | None = None,
192+
right_on: IndexLabel | AnyArrayLike | None = None,
193193
left_index: bool = False,
194194
right_index: bool = False,
195195
sort: bool = False,
@@ -239,7 +239,9 @@ def _cross_merge(
239239
return res
240240

241241

242-
def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):
242+
def _groupby_and_merge(
243+
by, left: DataFrame | Series, right: DataFrame | Series, merge_pieces
244+
):
243245
"""
244246
groupby & merge; we are always performing a left-by type operation
245247
@@ -255,7 +257,7 @@ def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):
255257
by = [by]
256258

257259
lby = left.groupby(by, sort=False)
258-
rby: groupby.DataFrameGroupBy | None = None
260+
rby: groupby.DataFrameGroupBy | groupby.SeriesGroupBy | None = None
259261

260262
# if we can groupby the rhs
261263
# then we can get vastly better perf
@@ -295,8 +297,8 @@ def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):
295297

296298

297299
def merge_ordered(
298-
left: DataFrame,
299-
right: DataFrame,
300+
left: DataFrame | Series,
301+
right: DataFrame | Series,
300302
on: IndexLabel | None = None,
301303
left_on: IndexLabel | None = None,
302304
right_on: IndexLabel | None = None,
@@ -737,9 +739,9 @@ def __init__(
737739
left: DataFrame | Series,
738740
right: DataFrame | Series,
739741
how: MergeHow | Literal["asof"] = "inner",
740-
on: IndexLabel | None = None,
741-
left_on: IndexLabel | None = None,
742-
right_on: IndexLabel | None = None,
742+
on: IndexLabel | AnyArrayLike | None = None,
743+
left_on: IndexLabel | AnyArrayLike | None = None,
744+
right_on: IndexLabel | AnyArrayLike | None = None,
743745
left_index: bool = False,
744746
right_index: bool = False,
745747
sort: bool = True,

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2141,7 +2141,7 @@ def groupby(
21412141
# Statistics, overridden ndarray methods
21422142

21432143
# TODO: integrate bottleneck
2144-
def count(self):
2144+
def count(self) -> int:
21452145
"""
21462146
Return number of non-NA/null observations in the Series.
21472147

pandas/io/formats/csvs.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222

2323
from pandas._libs import writers as libwriters
24+
from pandas._typing import SequenceNotStr
2425
from pandas.util._decorators import cache_readonly
2526

2627
from pandas.core.dtypes.generic import (
@@ -109,7 +110,7 @@ def decimal(self) -> str:
109110
return self.fmt.decimal
110111

111112
@property
112-
def header(self) -> bool | list[str]:
113+
def header(self) -> bool | SequenceNotStr[str]:
113114
return self.fmt.header
114115

115116
@property
@@ -213,7 +214,7 @@ def _need_to_save_header(self) -> bool:
213214
return bool(self._has_aliases or self.header)
214215

215216
@property
216-
def write_cols(self) -> Sequence[Hashable]:
217+
def write_cols(self) -> SequenceNotStr[Hashable]:
217218
if self._has_aliases:
218219
assert not isinstance(self.header, bool)
219220
if len(self.header) != len(self.cols):
@@ -224,7 +225,7 @@ def write_cols(self) -> Sequence[Hashable]:
224225
else:
225226
# self.cols is an ndarray derived from Index._format_native_types,
226227
# so its entries are strings, i.e. hashable
227-
return cast(Sequence[Hashable], self.cols)
228+
return cast(SequenceNotStr[Hashable], self.cols)
228229

229230
@property
230231
def encoded_labels(self) -> list[Hashable]:

pandas/io/formats/format.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
FloatFormatType,
106106
FormattersType,
107107
IndexLabel,
108+
SequenceNotStr,
108109
StorageOptions,
109110
WriteBuffer,
110111
)
@@ -566,7 +567,7 @@ def __init__(
566567
frame: DataFrame,
567568
columns: Axes | None = None,
568569
col_space: ColspaceArgType | None = None,
569-
header: bool | list[str] = True,
570+
header: bool | SequenceNotStr[str] = True,
570571
index: bool = True,
571572
na_rep: str = "NaN",
572573
formatters: FormattersType | None = None,

pandas/tests/io/test_sql.py

-2
Original file line numberDiff line numberDiff line change
@@ -3161,8 +3161,6 @@ def dtype_backend_data() -> DataFrame:
31613161
@pytest.fixture
31623162
def dtype_backend_expected():
31633163
def func(storage, dtype_backend, conn_name):
3164-
string_array: StringArray | ArrowStringArray
3165-
string_array_na: StringArray | ArrowStringArray
31663164
if storage == "python":
31673165
string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_))
31683166
string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_))

0 commit comments

Comments
 (0)