Skip to content

Commit 4ef638a

Browse files
authored
TYP: try out TypeGuard (#51309)
1 parent a29c206 commit 4ef638a

File tree

19 files changed

+69
-46
lines changed

19 files changed

+69
-46
lines changed

pandas/_libs/lib.pyi

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TODO(npdtypes): Many types specified here can be made more specific/accurate;
22
# the more specific versions are specified in comments
3-
3+
from decimal import Decimal
44
from typing import (
55
Any,
66
Callable,
@@ -13,9 +13,12 @@ from typing import (
1313

1414
import numpy as np
1515

16+
from pandas._libs.interval import Interval
17+
from pandas._libs.tslibs import Period
1618
from pandas._typing import (
1719
ArrayLike,
1820
DtypeObj,
21+
TypeGuard,
1922
npt,
2023
)
2124

@@ -38,13 +41,13 @@ def infer_dtype(value: object, skipna: bool = ...) -> str: ...
3841
def is_iterator(obj: object) -> bool: ...
3942
def is_scalar(val: object) -> bool: ...
4043
def is_list_like(obj: object, allow_sets: bool = ...) -> bool: ...
41-
def is_period(val: object) -> bool: ...
42-
def is_interval(val: object) -> bool: ...
43-
def is_decimal(val: object) -> bool: ...
44-
def is_complex(val: object) -> bool: ...
45-
def is_bool(val: object) -> bool: ...
46-
def is_integer(val: object) -> bool: ...
47-
def is_float(val: object) -> bool: ...
44+
def is_period(val: object) -> TypeGuard[Period]: ...
45+
def is_interval(val: object) -> TypeGuard[Interval]: ...
46+
def is_decimal(val: object) -> TypeGuard[Decimal]: ...
47+
def is_complex(val: object) -> TypeGuard[complex]: ...
48+
def is_bool(val: object) -> TypeGuard[bool | np.bool_]: ...
49+
def is_integer(val: object) -> TypeGuard[int | np.integer]: ...
50+
def is_float(val: object) -> TypeGuard[float]: ...
4851
def is_interval_array(values: np.ndarray) -> bool: ...
4952
def is_datetime64_array(values: np.ndarray) -> bool: ...
5053
def is_timedelta_or_timedelta64_array(values: np.ndarray) -> bool: ...

pandas/_typing.py

+6
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,19 @@
8484
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
8585
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]
8686

87+
if sys.version_info >= (3, 10):
88+
from typing import TypeGuard
89+
else:
90+
from typing_extensions import TypeGuard # pyright: reportUnusedImport = false
91+
8792
if sys.version_info >= (3, 11):
8893
from typing import Self
8994
else:
9095
from typing_extensions import Self # pyright: reportUnusedImport = false
9196
else:
9297
npt: Any = None
9398
Self: Any = None
99+
TypeGuard: Any = None
94100

95101
HashableT = TypeVar("HashableT", bound=Hashable)
96102

pandas/compat/numpy/function.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
overload,
2626
)
2727

28+
import numpy as np
2829
from numpy import ndarray
2930

3031
from pandas._libs.lib import (
@@ -215,7 +216,7 @@ def validate_clip_with_axis(
215216
)
216217

217218

218-
def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
219+
def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
219220
"""
220221
If this function is called via the 'numpy' library, the third parameter in
221222
its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
@@ -224,6 +225,8 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
224225
if not is_bool(skipna):
225226
args = (skipna,) + args
226227
skipna = True
228+
elif isinstance(skipna, np.bool_):
229+
skipna = bool(skipna)
227230

228231
validate_cum_func(args, kwargs, fname=name)
229232
return skipna

pandas/core/arrays/datetimelike.py

-1
Original file line numberDiff line numberDiff line change
@@ -2172,7 +2172,6 @@ def validate_periods(periods: int | float | None) -> int | None:
21722172
periods = int(periods)
21732173
elif not lib.is_integer(periods):
21742174
raise TypeError(f"periods must be a number, got {periods}")
2175-
periods = cast(int, periods)
21762175
return periods
21772176

21782177

pandas/core/dtypes/cast.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,9 @@ def maybe_box_native(value: Scalar | None | NAType) -> Scalar | None | NAType:
191191
scalar or Series
192192
"""
193193
if is_float(value):
194-
# error: Argument 1 to "float" has incompatible type
195-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
196-
# expected "Union[SupportsFloat, _SupportsIndex, str]"
197-
value = float(value) # type: ignore[arg-type]
194+
value = float(value)
198195
elif is_integer(value):
199-
# error: Argument 1 to "int" has incompatible type
200-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
201-
# expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
202-
value = int(value) # type: ignore[arg-type]
196+
value = int(value)
203197
elif is_bool(value):
204198
value = bool(value)
205199
elif isinstance(value, (np.datetime64, np.timedelta64)):

pandas/core/dtypes/inference.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
from collections import abc
66
from numbers import Number
77
import re
8-
from typing import Pattern
8+
from typing import (
9+
TYPE_CHECKING,
10+
Hashable,
11+
Pattern,
12+
)
913

1014
import numpy as np
1115

1216
from pandas._libs import lib
1317

18+
if TYPE_CHECKING:
19+
from pandas._typing import TypeGuard
20+
1421
is_bool = lib.is_bool
1522

1623
is_integer = lib.is_integer
@@ -30,7 +37,7 @@
3037
is_iterator = lib.is_iterator
3138

3239

33-
def is_number(obj) -> bool:
40+
def is_number(obj) -> TypeGuard[Number | np.number]:
3441
"""
3542
Check if the object is a number.
3643
@@ -132,7 +139,7 @@ def is_file_like(obj) -> bool:
132139
return bool(hasattr(obj, "__iter__"))
133140

134141

135-
def is_re(obj) -> bool:
142+
def is_re(obj) -> TypeGuard[Pattern]:
136143
"""
137144
Check if the object is a regex pattern instance.
138145
@@ -325,7 +332,7 @@ def is_named_tuple(obj) -> bool:
325332
return isinstance(obj, abc.Sequence) and hasattr(obj, "_fields")
326333

327334

328-
def is_hashable(obj) -> bool:
335+
def is_hashable(obj) -> TypeGuard[Hashable]:
329336
"""
330337
Return True if hash(obj) will succeed, False otherwise.
331338

pandas/core/frame.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -9516,11 +9516,7 @@ def melt(
95169516
)
95179517
def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame:
95189518
if not lib.is_integer(periods):
9519-
if not (
9520-
is_float(periods)
9521-
# error: "int" has no attribute "is_integer"
9522-
and periods.is_integer() # type: ignore[attr-defined]
9523-
):
9519+
if not (is_float(periods) and periods.is_integer()):
95249520
raise ValueError("periods must be an integer")
95259521
periods = int(periods)
95269522

@@ -10412,8 +10408,13 @@ def _series_round(ser: Series, decimals: int) -> Series:
1041210408
new_cols = list(_dict_round(self, decimals))
1041310409
elif is_integer(decimals):
1041410410
# Dispatch to Block.round
10411+
# Argument "decimals" to "round" of "BaseBlockManager" has incompatible
10412+
# type "Union[int, integer[Any]]"; expected "int"
1041510413
return self._constructor(
10416-
self._mgr.round(decimals=decimals, using_cow=using_copy_on_write()),
10414+
self._mgr.round(
10415+
decimals=decimals, # type: ignore[arg-type]
10416+
using_cow=using_copy_on_write(),
10417+
),
1041710418
).__finalize__(self, method="round")
1041810419
else:
1041910420
raise TypeError("decimals must be an integer, a dict-like or a Series")

pandas/core/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4096,7 +4096,8 @@ class animal locomotion
40964096
loc, new_index = index._get_loc_level(key, level=0)
40974097
if not drop_level:
40984098
if lib.is_integer(loc):
4099-
new_index = index[loc : loc + 1]
4099+
# Slice index must be an integer or None
4100+
new_index = index[loc : loc + 1] # type: ignore[misc]
41004101
else:
41014102
new_index = index[loc]
41024103
else:

pandas/core/indexes/api.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070

7171

7272
def get_objs_combined_axis(
73-
objs, intersect: bool = False, axis: Axis = 0, sort: bool = True, copy: bool = False
73+
objs,
74+
intersect: bool = False,
75+
axis: Axis = 0,
76+
sort: bool = True,
77+
copy: bool = False,
7478
) -> Index:
7579
"""
7680
Extract combined index: return intersection or union (depending on the

pandas/core/indexes/multi.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2700,6 +2700,7 @@ def _partial_tup_index(self, tup: tuple, side: Literal["left", "right"] = "left"
27002700
for k, (lab, lev, level_codes) in enumerate(zipped):
27012701
section = level_codes[start:end]
27022702

2703+
loc: npt.NDArray[np.intp] | np.intp | int
27032704
if lab not in lev and not isna(lab):
27042705
# short circuit
27052706
try:
@@ -2931,7 +2932,8 @@ def get_loc_level(self, key, level: IndexLabel = 0, drop_level: bool = True):
29312932
loc, mi = self._get_loc_level(key, level=level)
29322933
if not drop_level:
29332934
if lib.is_integer(loc):
2934-
mi = self[loc : loc + 1]
2935+
# Slice index must be an integer or None
2936+
mi = self[loc : loc + 1] # type: ignore[misc]
29352937
else:
29362938
mi = self[loc]
29372939
return loc, mi

pandas/core/indexes/period.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,7 @@ def _maybe_convert_timedelta(self, other) -> int | npt.NDArray[np.int64]:
303303

304304
raise raise_on_incompatible(self, other)
305305
elif is_integer(other):
306-
# integer is passed to .shift via
307-
# _add_datetimelike_methods basically
308-
# but ufunc may pass integer to _add_delta
306+
assert isinstance(other, int)
309307
return other
310308

311309
# raise when input doesn't have freq

pandas/core/indexing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1565,7 +1565,7 @@ def _is_scalar_access(self, key: tuple) -> bool:
15651565

15661566
return all(is_integer(k) for k in key)
15671567

1568-
def _validate_integer(self, key: int, axis: AxisInt) -> None:
1568+
def _validate_integer(self, key: int | np.integer, axis: AxisInt) -> None:
15691569
"""
15701570
Check that 'key' is a valid position in the desired axis.
15711571
@@ -2174,7 +2174,7 @@ def _ensure_iterable_column_indexer(self, column_indexer):
21742174
"""
21752175
Ensure that our column indexer is something that can be iterated over.
21762176
"""
2177-
ilocs: Sequence[int] | np.ndarray
2177+
ilocs: Sequence[int | np.integer] | np.ndarray
21782178
if is_integer(column_indexer):
21792179
ilocs = [column_indexer]
21802180
elif isinstance(column_indexer, slice):

pandas/core/reshape/concat.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ class _Concatenator:
391391
Orchestrates a concatenation operation for BlockManagers
392392
"""
393393

394+
sort: bool
395+
394396
def __init__(
395397
self,
396398
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
@@ -555,7 +557,9 @@ def __init__(
555557
raise ValueError(
556558
f"The 'sort' keyword only accepts boolean values; {sort} was passed."
557559
)
558-
self.sort = sort
560+
# Incompatible types in assignment (expression has type "Union[bool, bool_]",
561+
# variable has type "bool")
562+
self.sort = sort # type: ignore[assignment]
559563

560564
self.ignore_index = ignore_index
561565
self.verify_integrity = verify_integrity

pandas/core/reshape/merge.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2026,7 +2026,8 @@ def _get_merge_keys(
20262026
elif is_float_dtype(lt):
20272027
if not is_number(self.tolerance):
20282028
raise MergeError(msg)
2029-
if self.tolerance < 0:
2029+
# error: Unsupported operand types for > ("int" and "Number")
2030+
if self.tolerance < 0: # type: ignore[operator]
20302031
raise MergeError("tolerance must be positive")
20312032

20322033
else:

pandas/core/window/rolling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,7 @@ def _generate_cython_apply_func(
14021402
self,
14031403
args: tuple[Any, ...],
14041404
kwargs: dict[str, Any],
1405-
raw: bool,
1405+
raw: bool | np.bool_,
14061406
function: Callable[..., Any],
14071407
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int], np.ndarray]:
14081408
from pandas import Series

pandas/io/parsers/python_parser.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1372,4 +1372,5 @@ def _validate_skipfooter_arg(skipfooter: int) -> int:
13721372
if skipfooter < 0:
13731373
raise ValueError("skipfooter cannot be negative")
13741374

1375-
return skipfooter
1375+
# Incompatible return value type (got "Union[int, integer[Any]]", expected "int")
1376+
return skipfooter # type: ignore[return-value]

pandas/io/sql.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from pandas.core.dtypes.common import (
4545
is_datetime64tz_dtype,
4646
is_dict_like,
47-
is_integer,
4847
is_list_like,
4948
)
5049
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -1022,7 +1021,7 @@ def insert(
10221021
chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
10231022
num_inserted = exec_insert(conn, keys, chunk_iter)
10241023
# GH 46891
1025-
if is_integer(num_inserted):
1024+
if num_inserted is not None:
10261025
if total_inserted is None:
10271026
total_inserted = num_inserted
10281027
else:

pandas/tests/io/test_sql.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def psql_insert_copy(table, conn, keys, data_iter):
775775
"test_frame", conn, index=False, method=psql_insert_copy
776776
)
777777
# GH 46891
778-
if not isinstance(expected_count, int):
778+
if expected_count is None:
779779
assert result_count is None
780780
else:
781781
assert result_count == expected_count

pandas/util/_validators.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def validate_bool_kwarg(
250250
"""
251251
good_value = is_bool(value)
252252
if none_allowed:
253-
good_value = good_value or value is None
253+
good_value = good_value or (value is None)
254254

255255
if int_allowed:
256256
good_value = good_value or isinstance(value, int)
@@ -260,7 +260,7 @@ def validate_bool_kwarg(
260260
f'For argument "{arg_name}" expected type bool, received '
261261
f"type {type(value).__name__}."
262262
)
263-
return value
263+
return value # pyright: ignore[reportGeneralTypeIssues]
264264

265265

266266
def validate_fillna_kwargs(value, method, validate_scalar_dict_value: bool = True):
@@ -438,7 +438,7 @@ def validate_insert_loc(loc: int, length: int) -> int:
438438
loc += length
439439
if not 0 <= loc <= length:
440440
raise IndexError(f"loc must be an integer between -{length} and {length}")
441-
return loc
441+
return loc # pyright: ignore[reportGeneralTypeIssues]
442442

443443

444444
def check_dtype_backend(dtype_backend) -> None:

0 commit comments

Comments
 (0)