TYP: try out TypeGuard #51309

Merged
merged 15 commits into from
Mar 16, 2023
19 changes: 11 additions & 8 deletions pandas/_libs/lib.pyi
@@ -1,6 +1,6 @@
# TODO(npdtypes): Many types specified here can be made more specific/accurate;
# the more specific versions are specified in comments

from decimal import Decimal
from typing import (
Any,
Callable,
@@ -13,9 +13,12 @@ from typing import (

import numpy as np

from pandas._libs.interval import Interval
from pandas._libs.tslibs import Period
from pandas._typing import (
ArrayLike,
DtypeObj,
TypeGuard,
npt,
)

@@ -38,13 +41,13 @@ def infer_dtype(value: object, skipna: bool = ...) -> str: ...
def is_iterator(obj: object) -> bool: ...
def is_scalar(val: object) -> bool: ...
def is_list_like(obj: object, allow_sets: bool = ...) -> bool: ...
def is_period(val: object) -> bool: ...
def is_interval(val: object) -> bool: ...
def is_decimal(val: object) -> bool: ...
def is_complex(val: object) -> bool: ...
def is_bool(val: object) -> bool: ...
def is_integer(val: object) -> bool: ...
def is_float(val: object) -> bool: ...
def is_period(val: object) -> TypeGuard[Period]: ...
def is_interval(val: object) -> TypeGuard[Interval]: ...
def is_decimal(val: object) -> TypeGuard[Decimal]: ...
def is_complex(val: object) -> TypeGuard[complex]: ...
def is_bool(val: object) -> TypeGuard[bool | np.bool_]: ...
def is_integer(val: object) -> TypeGuard[int | np.integer]: ...
Contributor

@topper-123 topper-123 Feb 10, 2023

I kind of dislike having composite types inside the type guard (because it probably gives some uglier code in other locations). Maybe only use TypeGuard when it gives a single type?

Member Author

I get where you're coming from, but I also have an urge for "maximum accuracy" here. I'm going to keep it how it is for now but will be OK with changing it if consensus develops.

Member

I would keep the Union (but it would be nicer without it). I think not having the Unions might cause some incorrect unreachable code warnings from mypy/pyright.

if is_integer(x):
    if isinstance(x, np.number):  # I believe mypy would error here if is_integer did not return a Union
        x = x.item()

Contributor

To some extent, this change is revealing an issue throughout the pandas code.

Sometimes, when we call is_integer(value) or is_bool(value), the intent is to check whether the value is an integer or numpy integer, or bool or numpy boolean. In other cases, we actually want to check that the value is NOT a numpy type (although I'm not 100% sure of that).

If it is the case that we accept a numpy integer/bool whenever we have typed an argument as int or bool, respectively, then the types used within the pandas functions should change to reflect that.

Contributor

@topper-123 topper-123 Mar 10, 2023

I think if we want to check that it is a Python integer, we'd just do isinstance(x, int). My issue was more that by having a union type here, we'd have to have this union type in a lot of places, which doesn't feel like very clean code. My suggestion was to not type is_integer or the other checks with union return types.

But ok, it was just a preference, I can also accept it, as overall this PR is a nice improvement.

Contributor

> My issue was more that by having a union type here, we'd have to have this union type in a lot of places, which doesn't feel like very clean code.

Or, in every place where such a change is necessary, it makes you analyze the code and determine whether you really want to change the test to isinstance(x, int) or modify other parts of the code knowing that an np.integer is valid.

def is_float(val: object) -> TypeGuard[float]: ...
def is_interval_array(values: np.ndarray) -> bool: ...
def is_datetime64_array(values: np.ndarray) -> bool: ...
def is_timedelta_or_timedelta64_array(values: np.ndarray) -> bool: ...
6 changes: 6 additions & 0 deletions pandas/_typing.py
@@ -84,13 +84,19 @@
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard # pyright: reportUnusedImport = false

Contributor

I'd import from standard lib if on python >= 3.10, else from typing_extensions.

Contributor

This needs to take into account that typing_extensions may not be installed. All in all, I'd probably do something like this:

if PY310:
    from typing import TypeGuard
else:
    try:
        from typing_extensions import TypeGuard
    except ImportError:
        TypeGuard = bool

Also add typing_extensions to requirements-dev.txt; I think it's possible to add it like this:

typing_extensions; python_version < '3.10'
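
A minimal sketch of the suggested fallback, gated on `sys.version_info` instead of a project-level `PY310` constant so no extra import is needed. The `is_decimal` body is illustrative, not pandas' implementation, and the `bool` fallback is only safe here because the return annotation is a string that is never evaluated at runtime.

```python
import sys
from decimal import Decimal

if sys.version_info >= (3, 10):
    from typing import TypeGuard
else:
    try:
        from typing_extensions import TypeGuard
    except ImportError:
        # Fallback when typing_extensions is missing: annotate as plain bool.
        TypeGuard = bool


def is_decimal(val: object) -> "TypeGuard[Decimal]":
    # The quoted annotation is never evaluated at runtime, so the bool
    # fallback cannot fail on subscripting (bool[Decimal] would raise).
    return isinstance(val, Decimal)
```

With the fallback in place, the function behaves the same at runtime on every branch; only the precision seen by a type checker differs.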

Member Author

Does that work in an environment.yml file? requirements-dev.txt is auto-generated from the environment.yml.

Member Author

Checking for PY310 requires a circular import, which isn't great. I'm leaning towards abandoning this PR; I would be happy if you'd like to get it across the finish line.

Contributor

@topper-123 topper-123 Feb 17, 2023

Oh, I didn't know it is autogenerated, and I'm not knowledgeable enough about environment.yml to know if it's possible there. I could take a look. I can also try to take over, no problem.

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self # pyright: reportUnusedImport = false
else:
npt: Any = None
Self: Any = None
TypeGuard: Any = None
Contributor

TypeGuard = bool? In my understanding, a TypeGuard is like a bool that ensures that another object has some type…

Contributor

This should probably not be outside of the if TYPE_CHECKING section?
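
A minimal sketch of why the runtime placeholder can be `None`: the name only needs to exist at runtime, because the annotation that uses it is never evaluated. It assumes Python 3.10+ for the checker-only import and uses a quoted annotation; `is_re` here is an illustrative body, not necessarily pandas' own.

```python
import re
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; nothing is imported at runtime.
    from typing import TypeGuard
else:
    # Runtime placeholder, mirroring `TypeGuard: Any = None` in the diff.
    TypeGuard: Any = None


def is_re(obj: object) -> "TypeGuard[re.Pattern]":
    # Quoted annotation: the placeholder above is never subscripted.
    return isinstance(obj, re.Pattern)
```

If the annotation were evaluated at runtime, subscripting the `None` placeholder would fail, which is why the placeholder and the unevaluated annotations go together.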


HashableT = TypeVar("HashableT", bound=Hashable)

5 changes: 4 additions & 1 deletion pandas/compat/numpy/function.py
@@ -25,6 +25,7 @@
overload,
)

import numpy as np
from numpy import ndarray

from pandas._libs.lib import (
@@ -215,7 +216,7 @@ def validate_clip_with_axis(
)


def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
"""
If this function is called via the 'numpy' library, the third parameter in
its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
@@ -224,6 +225,8 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
if not is_bool(skipna):
args = (skipna,) + args
skipna = True
elif isinstance(skipna, np.bool_):
skipna = bool(skipna)

validate_cum_func(args, kwargs, fname=name)
return skipna
1 change: 0 additions & 1 deletion pandas/core/arrays/datetimelike.py
@@ -2182,7 +2182,6 @@ def validate_periods(periods: int | float | None) -> int | None:
periods = int(periods)
elif not lib.is_integer(periods):
raise TypeError(f"periods must be a number, got {periods}")
periods = cast(int, periods)
return periods


10 changes: 2 additions & 8 deletions pandas/core/dtypes/cast.py
@@ -194,15 +194,9 @@ def maybe_box_native(value: Scalar | None | NAType) -> Scalar | None | NAType:
scalar or Series
"""
if is_float(value):
# error: Argument 1 to "float" has incompatible type
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
# expected "Union[SupportsFloat, _SupportsIndex, str]"
value = float(value) # type: ignore[arg-type]
value = float(value)
elif is_integer(value):
# error: Argument 1 to "int" has incompatible type
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
# expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
value = int(value) # type: ignore[arg-type]
value = int(value)
elif is_bool(value):
value = bool(value)
elif isinstance(value, (np.datetime64, np.timedelta64)):
15 changes: 11 additions & 4 deletions pandas/core/dtypes/inference.py
@@ -5,12 +5,19 @@
from collections import abc
from numbers import Number
import re
from typing import Pattern
from typing import (
TYPE_CHECKING,
Hashable,
Pattern,
)

import numpy as np

from pandas._libs import lib

if TYPE_CHECKING:
from pandas._typing import TypeGuard

is_bool = lib.is_bool

is_integer = lib.is_integer
@@ -30,7 +37,7 @@
is_iterator = lib.is_iterator


def is_number(obj) -> bool:
def is_number(obj) -> TypeGuard[Number | np.number]:
"""
Check if the object is a number.

@@ -132,7 +139,7 @@ def is_file_like(obj) -> bool:
return bool(hasattr(obj, "__iter__"))


def is_re(obj) -> bool:
def is_re(obj) -> TypeGuard[Pattern]:
"""
Check if the object is a regex pattern instance.

@@ -325,7 +332,7 @@ def is_named_tuple(obj) -> bool:
return isinstance(obj, abc.Sequence) and hasattr(obj, "_fields")


def is_hashable(obj) -> bool:
def is_hashable(obj) -> TypeGuard[Hashable]:
"""
Return True if hash(obj) will succeed, False otherwise.

13 changes: 7 additions & 6 deletions pandas/core/frame.py
@@ -9501,11 +9501,7 @@ def melt(
)
def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame:
if not lib.is_integer(periods):
if not (
is_float(periods)
# error: "int" has no attribute "is_integer"
and periods.is_integer() # type: ignore[attr-defined]
):
if not (is_float(periods) and periods.is_integer()):
raise ValueError("periods must be an integer")
periods = int(periods)

@@ -10397,8 +10393,13 @@ def _series_round(ser: Series, decimals: int) -> Series:
new_cols = list(_dict_round(self, decimals))
elif is_integer(decimals):
# Dispatch to Block.round
# Argument "decimals" to "round" of "BaseBlockManager" has incompatible
# type "Union[int, integer[Any]]"; expected "int"
return self._constructor(
self._mgr.round(decimals=decimals, using_cow=using_copy_on_write()),
self._mgr.round(
decimals=decimals, # type: ignore[arg-type]
using_cow=using_copy_on_write(),
),
).__finalize__(self, method="round")
else:
raise TypeError("decimals must be an integer, a dict-like or a Series")
3 changes: 2 additions & 1 deletion pandas/core/generic.py
@@ -4089,7 +4089,8 @@ class animal locomotion
loc, new_index = index._get_loc_level(key, level=0)
if not drop_level:
if lib.is_integer(loc):
new_index = index[loc : loc + 1]
# Slice index must be an integer or None
new_index = index[loc : loc + 1] # type: ignore[misc]
else:
new_index = index[loc]
else:
6 changes: 5 additions & 1 deletion pandas/core/indexes/api.py
@@ -70,7 +70,11 @@


def get_objs_combined_axis(
objs, intersect: bool = False, axis: Axis = 0, sort: bool = True, copy: bool = False
objs,
intersect: bool = False,
axis: Axis = 0,
sort: bool = True,
copy: bool = False,
) -> Index:
"""
Extract combined index: return intersection or union (depending on the
4 changes: 3 additions & 1 deletion pandas/core/indexes/multi.py
@@ -2699,6 +2699,7 @@ def _partial_tup_index(self, tup: tuple, side: Literal["left", "right"] = "left"
for k, (lab, lev, level_codes) in enumerate(zipped):
section = level_codes[start:end]

loc: npt.NDArray[np.intp] | np.intp | int
if lab not in lev and not isna(lab):
# short circuit
try:
@@ -2930,7 +2931,8 @@ def get_loc_level(self, key, level: IndexLabel = 0, drop_level: bool = True):
loc, mi = self._get_loc_level(key, level=level)
if not drop_level:
if lib.is_integer(loc):
mi = self[loc : loc + 1]
# Slice index must be an integer or None
mi = self[loc : loc + 1] # type: ignore[misc]
else:
mi = self[loc]
return loc, mi
4 changes: 1 addition & 3 deletions pandas/core/indexes/period.py
@@ -298,9 +298,7 @@ def _maybe_convert_timedelta(self, other) -> int | npt.NDArray[np.int64]:

raise raise_on_incompatible(self, other)
elif is_integer(other):
# integer is passed to .shift via
# _add_datetimelike_methods basically
# but ufunc may pass integer to _add_delta
assert isinstance(other, int)
return other

# raise when input doesn't have freq
4 changes: 2 additions & 2 deletions pandas/core/indexing.py
@@ -1562,7 +1562,7 @@ def _is_scalar_access(self, key: tuple) -> bool:

return all(is_integer(k) for k in key)

def _validate_integer(self, key: int, axis: AxisInt) -> None:
def _validate_integer(self, key: int | np.integer, axis: AxisInt) -> None:
"""
Check that 'key' is a valid position in the desired axis.

@@ -2171,7 +2171,7 @@ def _ensure_iterable_column_indexer(self, column_indexer):
"""
Ensure that our column indexer is something that can be iterated over.
"""
ilocs: Sequence[int] | np.ndarray
ilocs: Sequence[int | np.integer] | np.ndarray
if is_integer(column_indexer):
ilocs = [column_indexer]
elif isinstance(column_indexer, slice):
6 changes: 5 additions & 1 deletion pandas/core/reshape/concat.py
@@ -391,6 +391,8 @@ class _Concatenator:
Orchestrates a concatenation operation for BlockManagers
"""

sort: bool

def __init__(
self,
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
@@ -555,7 +557,9 @@ def __init__(
raise ValueError(
f"The 'sort' keyword only accepts boolean values; {sort} was passed."
)
self.sort = sort
# Incompatible types in assignment (expression has type "Union[bool, bool_]",
# variable has type "bool")
self.sort = sort # type: ignore[assignment]

self.ignore_index = ignore_index
self.verify_integrity = verify_integrity
3 changes: 2 additions & 1 deletion pandas/core/reshape/merge.py
@@ -2026,7 +2026,8 @@ def _get_merge_keys(
elif is_float_dtype(lt):
if not is_number(self.tolerance):
raise MergeError(msg)
if self.tolerance < 0:
# error: Unsupported operand types for > ("int" and "Number")
if self.tolerance < 0: # type: ignore[operator]
raise MergeError("tolerance must be positive")

else:
2 changes: 1 addition & 1 deletion pandas/core/window/rolling.py
@@ -1402,7 +1402,7 @@ def _generate_cython_apply_func(
self,
args: tuple[Any, ...],
kwargs: dict[str, Any],
raw: bool,
raw: bool | np.bool_,
function: Callable[..., Any],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int], np.ndarray]:
from pandas import Series
3 changes: 2 additions & 1 deletion pandas/io/parsers/python_parser.py
@@ -1348,4 +1348,5 @@ def _validate_skipfooter_arg(skipfooter: int) -> int:
if skipfooter < 0:
raise ValueError("skipfooter cannot be negative")

return skipfooter
# Incompatible return value type (got "Union[int, integer[Any]]", expected "int")
return skipfooter # type: ignore[return-value]
Contributor

Could remove the # type: ignore by changing the signature:

def _validate_skipfooter_arg(skipfooter: int | np.integer) -> int | np.integer:

Contributor

I think it's not good to change type signatures to keep the type checkers happy, mostly because it could easily spread to be needed everywhere, which would be a PITA.

Contributor

> I think it's not good to change type signatures to keep the type checkers happy, mostly because it could easily spread to be needed everywhere, which would be a PITA.

But that's part of the point of type checking! The fact that you had to do a # type: ignore is a signal that the code is inconsistent in terms of the expected arguments. In this particular case, the function is called only once, so changing the type in the function signature won't matter. Not sure of the other cases that I noted elsewhere in the PR.
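
Besides widening the signature or adding a `# type: ignore`, a third option (not proposed in this thread, but similar in spirit to the `bool(skipna)` normalization added in `validate_cum_func_with_skipna`) is to normalize the value to a plain `int` at the boundary. The sketch below uses `operator.index`, which accepts any object implementing `__index__`; the function name `validate_skipfooter` and the "must be an integer" message are assumptions made for illustration.

```python
import operator
from typing import SupportsIndex


def validate_skipfooter(skipfooter: SupportsIndex) -> int:
    # bool implements __index__, so reject it explicitly (assumption:
    # booleans should not count as integers here).
    if isinstance(skipfooter, bool):
        raise ValueError("skipfooter must be an integer")
    try:
        # operator.index always returns a plain Python int and raises
        # TypeError for objects without __index__ (floats, strings, ...).
        value = operator.index(skipfooter)
    except TypeError:
        raise ValueError("skipfooter must be an integer") from None
    if value < 0:
        raise ValueError("skipfooter cannot be negative")
    return value
```

The return type stays `int` without an ignore comment, at the cost of the function no longer returning the caller's original object.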

3 changes: 1 addition & 2 deletions pandas/io/sql.py
@@ -44,7 +44,6 @@
from pandas.core.dtypes.common import (
is_datetime64tz_dtype,
is_dict_like,
is_integer,
is_list_like,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -1022,7 +1021,7 @@ def insert(
chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
num_inserted = exec_insert(conn, keys, chunk_iter)
# GH 46891
if is_integer(num_inserted):
if num_inserted is not None:
if total_inserted is None:
total_inserted = num_inserted
else:
2 changes: 1 addition & 1 deletion pandas/tests/io/test_sql.py
@@ -775,7 +775,7 @@ def psql_insert_copy(table, conn, keys, data_iter):
"test_frame", conn, index=False, method=psql_insert_copy
)
# GH 46891
if not isinstance(expected_count, int):
if expected_count is None:
assert result_count is None
else:
assert result_count == expected_count
6 changes: 3 additions & 3 deletions pandas/util/_validators.py
@@ -250,7 +250,7 @@ def validate_bool_kwarg(
"""
good_value = is_bool(value)
if none_allowed:
good_value = good_value or value is None
good_value = good_value or (value is None)

if int_allowed:
good_value = good_value or isinstance(value, int)
@@ -260,7 +260,7 @@ def validate_bool_kwarg(
f'For argument "{arg_name}" expected type bool, received '
f"type {type(value).__name__}."
)
return value
return value # pyright: ignore[reportGeneralTypeIssues]
Member

Could also add this file to the long list here https://github.com/pandas-dev/pandas/blob/main/pyright_reportGeneralTypeIssues.json#L17

But I wouldn't mind starting to use line-by-line ignores. I think we started with file-based ignores because in most cases we just have too many errors.

Contributor

Things like this will show up in many places where we use is_integer (after this PR), so IMO file-based exclusions will not work: every time we use is_integer in a new file, these typing errors may show up. Using file-based ignores would therefore quickly make pyright less meaningful.

So I think we will have to either:

  • use line based ignores (allows the widest use of TypeGuard)
  • only use a single type inside TypeGuard (minimizes issues when using functions that implement TypeGuard)

Contributor

You can get rid of the error by changing the sig of the function:

def validate_bool_kwarg(
    value: bool | int | None | np.bool_, arg_name, none_allowed: bool = True, int_allowed: bool = False
) -> bool | int | None | np.bool_:

In this case, the type checker was telling you that np.bool_ is a valid type for the value argument.



def validate_fillna_kwargs(value, method, validate_scalar_dict_value: bool = True):
@@ -438,7 +438,7 @@ def validate_insert_loc(loc: int, length: int) -> int:
loc += length
if not 0 <= loc <= length:
raise IndexError(f"loc must be an integer between -{length} and {length}")
return loc
return loc # pyright: ignore[reportGeneralTypeIssues]
Contributor

Change the sig here too:

def validate_insert_loc(loc: int | np.integer, length: int) -> int | np.integer:

Note: if you really don't want to allow np.integer, then change the test inside to not use is_integer().
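
The alternative mentioned in the note, a strict check that does not go through `is_integer()`, might look like the sketch below. The name `validate_insert_loc_strict`, the `TypeError` message, and the decision to reject bools are assumptions for illustration; the `IndexError` message follows the diff.

```python
def validate_insert_loc_strict(loc: object, length: int) -> int:
    # Plain isinstance narrows loc to int for the type checker, so the
    # declared return type stays int and no ignore comment is needed.
    if not isinstance(loc, int) or isinstance(loc, bool):
        raise TypeError(f"loc must be an integer between -{length} and {length}")
    if loc < 0:
        # Negative positions count from the end.
        loc += length
    if not 0 <= loc <= length:
        raise IndexError(f"loc must be an integer between -{length} and {length}")
    return loc
```

The trade-off is behavioral: a NumPy integer that `is_integer()` would accept is rejected by this strict version, so it only fits call sites where that is the intent.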



def check_dtype_backend(dtype_backend) -> None: