Skip to content

Commit 1222b45

Browse files
twoertweincbpygit
authored andcommitted
TYP: more misc annotations (pandas-dev#56675)
* TYP: more misc annotations * workaround for Generic * fix drop
1 parent b1c3c4f commit 1222b45

File tree

16 files changed

+90
-72
lines changed

16 files changed

+90
-72
lines changed

pandas/_config/config.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from collections.abc import (
7676
Generator,
7777
Iterable,
78+
Sequence,
7879
)
7980

8081

@@ -853,7 +854,7 @@ def inner(x) -> None:
853854
return inner
854855

855856

856-
def is_instance_factory(_type) -> Callable[[Any], None]:
857+
def is_instance_factory(_type: type | tuple[type, ...]) -> Callable[[Any], None]:
857858
"""
858859
859860
Parameters
@@ -866,8 +867,7 @@ def is_instance_factory(_type) -> Callable[[Any], None]:
866867
ValueError if x is not an instance of `_type`
867868
868869
"""
869-
if isinstance(_type, (tuple, list)):
870-
_type = tuple(_type)
870+
if isinstance(_type, tuple):
871871
type_repr = "|".join(map(str, _type))
872872
else:
873873
type_repr = f"'{_type}'"
@@ -879,7 +879,7 @@ def inner(x) -> None:
879879
return inner
880880

881881

882-
def is_one_of_factory(legal_values) -> Callable[[Any], None]:
882+
def is_one_of_factory(legal_values: Sequence) -> Callable[[Any], None]:
883883
callables = [c for c in legal_values if callable(c)]
884884
legal_values = [c for c in legal_values if not callable(c)]
885885

@@ -930,7 +930,7 @@ def is_nonnegative_int(value: object) -> None:
930930
is_text = is_instance_factory((str, bytes))
931931

932932

933-
def is_callable(obj) -> bool:
933+
def is_callable(obj: object) -> bool:
934934
"""
935935
936936
Parameters

pandas/core/arrays/period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def asfreq(self, freq=None, how: str = "E") -> Self:
759759
# ------------------------------------------------------------------
760760
# Rendering Methods
761761

762-
def _formatter(self, boxed: bool = False):
762+
def _formatter(self, boxed: bool = False) -> Callable[[object], str]:
763763
if boxed:
764764
return str
765765
return "'{}'".format

pandas/core/arrays/timedeltas.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ def sequence_to_td64ns(
10801080
return data, inferred_freq
10811081

10821082

1083-
def _ints_to_td64ns(data, unit: str = "ns"):
1083+
def _ints_to_td64ns(data, unit: str = "ns") -> tuple[np.ndarray, bool]:
10841084
"""
10851085
Convert an ndarray with integer-dtype to timedelta64[ns] dtype, treating
10861086
the integers as multiples of the given timedelta unit.
@@ -1120,7 +1120,9 @@ def _ints_to_td64ns(data, unit: str = "ns"):
11201120
return data, copy_made
11211121

11221122

1123-
def _objects_to_td64ns(data, unit=None, errors: DateTimeErrorChoices = "raise"):
1123+
def _objects_to_td64ns(
1124+
data, unit=None, errors: DateTimeErrorChoices = "raise"
1125+
) -> np.ndarray:
11241126
"""
11251127
Convert a object-dtyped or string-dtyped array into an
11261128
timedelta64[ns]-dtyped array.

pandas/core/config_init.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def is_terminal() -> bool:
329329
"min_rows",
330330
10,
331331
pc_min_rows_doc,
332-
validator=is_instance_factory([type(None), int]),
332+
validator=is_instance_factory((type(None), int)),
333333
)
334334
cf.register_option("max_categories", 8, pc_max_categories_doc, validator=is_int)
335335

@@ -369,7 +369,7 @@ def is_terminal() -> bool:
369369
cf.register_option("chop_threshold", None, pc_chop_threshold_doc)
370370
cf.register_option("max_seq_items", 100, pc_max_seq_items)
371371
cf.register_option(
372-
"width", 80, pc_width_doc, validator=is_instance_factory([type(None), int])
372+
"width", 80, pc_width_doc, validator=is_instance_factory((type(None), int))
373373
)
374374
cf.register_option(
375375
"memory_usage",
@@ -850,14 +850,14 @@ def register_converter_cb(key) -> None:
850850
"format.thousands",
851851
None,
852852
styler_thousands,
853-
validator=is_instance_factory([type(None), str]),
853+
validator=is_instance_factory((type(None), str)),
854854
)
855855

856856
cf.register_option(
857857
"format.na_rep",
858858
None,
859859
styler_na_rep,
860-
validator=is_instance_factory([type(None), str]),
860+
validator=is_instance_factory((type(None), str)),
861861
)
862862

863863
cf.register_option(
@@ -867,11 +867,15 @@ def register_converter_cb(key) -> None:
867867
validator=is_one_of_factory([None, "html", "latex", "latex-math"]),
868868
)
869869

870+
# error: Argument 1 to "is_instance_factory" has incompatible type "tuple[
871+
# ..., <typing special form>, ...]"; expected "type | tuple[type, ...]"
870872
cf.register_option(
871873
"format.formatter",
872874
None,
873875
styler_formatter,
874-
validator=is_instance_factory([type(None), dict, Callable, str]),
876+
validator=is_instance_factory(
877+
(type(None), dict, Callable, str) # type: ignore[arg-type]
878+
),
875879
)
876880

877881
cf.register_option("html.mathjax", True, styler_mathjax, validator=is_bool)
@@ -898,7 +902,7 @@ def register_converter_cb(key) -> None:
898902
"latex.environment",
899903
None,
900904
styler_environment,
901-
validator=is_instance_factory([type(None), str]),
905+
validator=is_instance_factory((type(None), str)),
902906
)
903907

904908

pandas/core/dtypes/missing.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -557,11 +557,13 @@ def _array_equivalent_float(left: np.ndarray, right: np.ndarray) -> bool:
557557
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())
558558

559559

560-
def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):
560+
def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray) -> bool:
561561
return np.array_equal(left.view("i8"), right.view("i8"))
562562

563563

564-
def _array_equivalent_object(left: np.ndarray, right: np.ndarray, strict_nan: bool):
564+
def _array_equivalent_object(
565+
left: np.ndarray, right: np.ndarray, strict_nan: bool
566+
) -> bool:
565567
left = ensure_object(left)
566568
right = ensure_object(right)
567569

pandas/core/frame.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@
233233
IndexLabel,
234234
JoinValidate,
235235
Level,
236+
ListLike,
236237
MergeHow,
237238
MergeValidate,
238239
MutableMappingT,
@@ -5349,11 +5350,11 @@ def reindex(
53495350
@overload
53505351
def drop(
53515352
self,
5352-
labels: IndexLabel = ...,
5353+
labels: IndexLabel | ListLike = ...,
53535354
*,
53545355
axis: Axis = ...,
5355-
index: IndexLabel = ...,
5356-
columns: IndexLabel = ...,
5356+
index: IndexLabel | ListLike = ...,
5357+
columns: IndexLabel | ListLike = ...,
53575358
level: Level = ...,
53585359
inplace: Literal[True],
53595360
errors: IgnoreRaise = ...,
@@ -5363,11 +5364,11 @@ def drop(
53635364
@overload
53645365
def drop(
53655366
self,
5366-
labels: IndexLabel = ...,
5367+
labels: IndexLabel | ListLike = ...,
53675368
*,
53685369
axis: Axis = ...,
5369-
index: IndexLabel = ...,
5370-
columns: IndexLabel = ...,
5370+
index: IndexLabel | ListLike = ...,
5371+
columns: IndexLabel | ListLike = ...,
53715372
level: Level = ...,
53725373
inplace: Literal[False] = ...,
53735374
errors: IgnoreRaise = ...,
@@ -5377,11 +5378,11 @@ def drop(
53775378
@overload
53785379
def drop(
53795380
self,
5380-
labels: IndexLabel = ...,
5381+
labels: IndexLabel | ListLike = ...,
53815382
*,
53825383
axis: Axis = ...,
5383-
index: IndexLabel = ...,
5384-
columns: IndexLabel = ...,
5384+
index: IndexLabel | ListLike = ...,
5385+
columns: IndexLabel | ListLike = ...,
53855386
level: Level = ...,
53865387
inplace: bool = ...,
53875388
errors: IgnoreRaise = ...,
@@ -5390,11 +5391,11 @@ def drop(
53905391

53915392
def drop(
53925393
self,
5393-
labels: IndexLabel | None = None,
5394+
labels: IndexLabel | ListLike = None,
53945395
*,
53955396
axis: Axis = 0,
5396-
index: IndexLabel | None = None,
5397-
columns: IndexLabel | None = None,
5397+
index: IndexLabel | ListLike = None,
5398+
columns: IndexLabel | ListLike = None,
53985399
level: Level | None = None,
53995400
inplace: bool = False,
54005401
errors: IgnoreRaise = "raise",

pandas/core/generic.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
IntervalClosedType,
6666
JSONSerializable,
6767
Level,
68+
ListLike,
6869
Manager,
6970
NaPosition,
7071
NDFrameT,
@@ -4709,11 +4710,11 @@ def reindex_like(
47094710
@overload
47104711
def drop(
47114712
self,
4712-
labels: IndexLabel = ...,
4713+
labels: IndexLabel | ListLike = ...,
47134714
*,
47144715
axis: Axis = ...,
4715-
index: IndexLabel = ...,
4716-
columns: IndexLabel = ...,
4716+
index: IndexLabel | ListLike = ...,
4717+
columns: IndexLabel | ListLike = ...,
47174718
level: Level | None = ...,
47184719
inplace: Literal[True],
47194720
errors: IgnoreRaise = ...,
@@ -4723,11 +4724,11 @@ def drop(
47234724
@overload
47244725
def drop(
47254726
self,
4726-
labels: IndexLabel = ...,
4727+
labels: IndexLabel | ListLike = ...,
47274728
*,
47284729
axis: Axis = ...,
4729-
index: IndexLabel = ...,
4730-
columns: IndexLabel = ...,
4730+
index: IndexLabel | ListLike = ...,
4731+
columns: IndexLabel | ListLike = ...,
47314732
level: Level | None = ...,
47324733
inplace: Literal[False] = ...,
47334734
errors: IgnoreRaise = ...,
@@ -4737,11 +4738,11 @@ def drop(
47374738
@overload
47384739
def drop(
47394740
self,
4740-
labels: IndexLabel = ...,
4741+
labels: IndexLabel | ListLike = ...,
47414742
*,
47424743
axis: Axis = ...,
4743-
index: IndexLabel = ...,
4744-
columns: IndexLabel = ...,
4744+
index: IndexLabel | ListLike = ...,
4745+
columns: IndexLabel | ListLike = ...,
47454746
level: Level | None = ...,
47464747
inplace: bool_t = ...,
47474748
errors: IgnoreRaise = ...,
@@ -4750,11 +4751,11 @@ def drop(
47504751

47514752
def drop(
47524753
self,
4753-
labels: IndexLabel | None = None,
4754+
labels: IndexLabel | ListLike = None,
47544755
*,
47554756
axis: Axis = 0,
4756-
index: IndexLabel | None = None,
4757-
columns: IndexLabel | None = None,
4757+
index: IndexLabel | ListLike = None,
4758+
columns: IndexLabel | ListLike = None,
47584759
level: Level | None = None,
47594760
inplace: bool_t = False,
47604761
errors: IgnoreRaise = "raise",

pandas/core/methods/selectn.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from typing import (
1212
TYPE_CHECKING,
13+
Generic,
1314
cast,
1415
final,
1516
)
@@ -32,32 +33,41 @@
3233
from pandas._typing import (
3334
DtypeObj,
3435
IndexLabel,
36+
NDFrameT,
3537
)
3638

3739
from pandas import (
3840
DataFrame,
3941
Series,
4042
)
43+
else:
44+
# Generic[...] requires a non-str, provide it with a plain TypeVar at
45+
# runtime to avoid circular imports
46+
from pandas._typing import T
4147

48+
NDFrameT = T
49+
DataFrame = T
50+
Series = T
4251

43-
class SelectN:
44-
def __init__(self, obj, n: int, keep: str) -> None:
52+
53+
class SelectN(Generic[NDFrameT]):
54+
def __init__(self, obj: NDFrameT, n: int, keep: str) -> None:
4555
self.obj = obj
4656
self.n = n
4757
self.keep = keep
4858

4959
if self.keep not in ("first", "last", "all"):
5060
raise ValueError('keep must be either "first", "last" or "all"')
5161

52-
def compute(self, method: str) -> DataFrame | Series:
62+
def compute(self, method: str) -> NDFrameT:
5363
raise NotImplementedError
5464

5565
@final
56-
def nlargest(self):
66+
def nlargest(self) -> NDFrameT:
5767
return self.compute("nlargest")
5868

5969
@final
60-
def nsmallest(self):
70+
def nsmallest(self) -> NDFrameT:
6171
return self.compute("nsmallest")
6272

6373
@final
@@ -72,7 +82,7 @@ def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
7282
return needs_i8_conversion(dtype)
7383

7484

75-
class SelectNSeries(SelectN):
85+
class SelectNSeries(SelectN[Series]):
7686
"""
7787
Implement n largest/smallest for Series
7888
@@ -163,7 +173,7 @@ def compute(self, method: str) -> Series:
163173
return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
164174

165175

166-
class SelectNFrame(SelectN):
176+
class SelectNFrame(SelectN[DataFrame]):
167177
"""
168178
Implement n largest/smallest for DataFrame
169179

pandas/core/missing.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -969,14 +969,11 @@ def _pad_2d(
969969
values: np.ndarray,
970970
limit: int | None = None,
971971
mask: npt.NDArray[np.bool_] | None = None,
972-
):
972+
) -> tuple[np.ndarray, npt.NDArray[np.bool_]]:
973973
mask = _fillna_prep(values, mask)
974974

975975
if values.size:
976976
algos.pad_2d_inplace(values, mask, limit=limit)
977-
else:
978-
# for test coverage
979-
pass
980977
return values, mask
981978

982979

0 commit comments

Comments
 (0)