TYP: Update pyright #56892

Merged · 4 commits · Jan 16, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -132,7 +132,7 @@ repos:
types: [python]
stages: [manual]
additional_dependencies: &pyright_dependencies
- [email protected].339
- [email protected].347
- id: pyright
# note: assumes python env is setup and activated
name: pyright reportGeneralTypeIssues
2 changes: 1 addition & 1 deletion pandas/_config/config.py
@@ -88,7 +88,7 @@ class DeprecatedOption(NamedTuple):

class RegisteredOption(NamedTuple):
key: str
defval: object
defval: Any
doc: str
validator: Callable[[object], Any] | None
cb: Callable[[str], Any] | None
7 changes: 5 additions & 2 deletions pandas/_config/localization.py
@@ -10,7 +10,10 @@
import platform
import re
import subprocess
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
cast,
)

from pandas._config.config import options

@@ -152,7 +155,7 @@ def get_locales(
out_locales = []
for x in split_raw_locales:
try:
out_locales.append(str(x, encoding=options.display.encoding))
out_locales.append(str(x, encoding=cast(str, options.display.encoding)))
except UnicodeError:
# 'locale -a' is used to populated 'raw_locales' and on
# Redhat 7 Linux (and maybe others) prints locale names
28 changes: 25 additions & 3 deletions pandas/_libs/lib.pyi
@@ -69,16 +69,26 @@ def fast_multiget(
mapping: dict,
keys: np.ndarray, # object[:]
default=...,
) -> np.ndarray: ...
) -> ArrayLike: ...

Member Author commented:
maybe_convert_objects returns an ExtensionArray or an np.ndarray. These functions call maybe_convert_objects only when convert=True and otherwise return an np.ndarray (see the overloads).
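
For reference, a minimal self-contained sketch of this overload pattern (not pandas code; the ExtensionArray placeholder and the toy body are assumptions for illustration). A convert=Literal[False] overload narrows the return type to np.ndarray, while the general signature keeps the wider ArrayLike:

from typing import Any, Callable, Literal, Union, overload

import numpy as np


class ExtensionArray:  # stand-in for pandas' ExtensionArray
    pass


ArrayLike = Union[np.ndarray, ExtensionArray]


@overload
def map_infer(
    arr: np.ndarray, f: Callable[[Any], Any], *, convert: Literal[False], ignore_na: bool = ...
) -> np.ndarray: ...
@overload
def map_infer(
    arr: np.ndarray, f: Callable[[Any], Any], *, convert: bool = ..., ignore_na: bool = ...
) -> ArrayLike: ...
def map_infer(arr, f, *, convert=True, ignore_na=False):
    # Toy body: apply f element-wise. The real function calls
    # maybe_convert_objects when convert=True, which may hand back an
    # ExtensionArray; with convert=False it stays an object ndarray.
    return np.array([f(x) for x in arr], dtype=object)


narrowed = map_infer(np.array([1, 2], dtype=object), str, convert=False)  # np.ndarray
widened = map_infer(np.array([1, 2], dtype=object), str)                  # ArrayLike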

def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ...
def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ...
@overload
def map_infer(
arr: np.ndarray,
f: Callable[[Any], Any],
convert: bool = ...,
*,
convert: Literal[False],
ignore_na: bool = ...,
) -> np.ndarray: ...
@overload
def map_infer(
arr: np.ndarray,
f: Callable[[Any], Any],
*,
convert: bool = ...,
ignore_na: bool = ...,
) -> ArrayLike: ...
@overload
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
@@ -164,14 +174,26 @@ def is_all_arraylike(obj: list) -> bool: ...
# Functions which in reality take memoryviews

def memory_usage_of_objects(arr: np.ndarray) -> int: ... # object[:] # np.int64
@overload
def map_infer_mask(
arr: np.ndarray,
f: Callable[[Any], Any],
mask: np.ndarray, # const uint8_t[:]
convert: bool = ...,
*,
convert: Literal[False],
na_value: Any = ...,
dtype: np.dtype = ...,
) -> np.ndarray: ...
@overload
def map_infer_mask(
arr: np.ndarray,
f: Callable[[Any], Any],
mask: np.ndarray, # const uint8_t[:]
*,
convert: bool = ...,
na_value: Any = ...,
dtype: np.dtype = ...,
) -> ArrayLike: ...
def indices_fast(
index: npt.NDArray[np.intp],
labels: np.ndarray, # const int64_t[:]
13 changes: 7 additions & 6 deletions pandas/_libs/lib.pyx
@@ -2864,10 +2864,11 @@ def map_infer_mask(
ndarray[object] arr,
object f,
const uint8_t[:] mask,
*,
bint convert=True,
object na_value=no_default,
cnp.dtype dtype=np.dtype(object)
) -> np.ndarray:
) -> "ArrayLike":
"""
Substitute for np.vectorize with pandas-friendly dtype inference.

@@ -2887,7 +2888,7 @@

Returns
-------
np.ndarray
np.ndarray or an ExtensionArray
"""
cdef Py_ssize_t n = len(arr)
result = np.empty(n, dtype=dtype)
@@ -2941,8 +2942,8 @@ def _map_infer_mask(
@cython.boundscheck(False)
@cython.wraparound(False)
def map_infer(
ndarray arr, object f, bint convert=True, bint ignore_na=False
) -> np.ndarray:
ndarray arr, object f, *, bint convert=True, bint ignore_na=False
) -> "ArrayLike":
"""
Substitute for np.vectorize with pandas-friendly dtype inference.

Expand All @@ -2956,7 +2957,7 @@ def map_infer(

Returns
-------
np.ndarray
np.ndarray or an ExtensionArray
"""
cdef:
Py_ssize_t i, n
@@ -3091,7 +3092,7 @@ def to_object_array_tuples(rows: object) -> np.ndarray:

@cython.wraparound(False)
@cython.boundscheck(False)
def fast_multiget(dict mapping, object[:] keys, default=np.nan) -> np.ndarray:
def fast_multiget(dict mapping, object[:] keys, default=np.nan) -> "ArrayLike":
cdef:
Py_ssize_t i, n = len(keys)
object val
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_.py
@@ -4,6 +4,7 @@
TYPE_CHECKING,
ClassVar,
Literal,
cast,
)

import numpy as np
@@ -637,7 +638,7 @@ def _str_map(
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(dtype), # type: ignore[arg-type]
dtype=np.dtype(cast(type, dtype)),

Member Author commented:
map_infer_mask requires that dtype is an np.dtype, but np.dtype is a generic type: since the call to it looks invalid to mypy, mypy cannot infer the correct generic parameter for np.dtype, which caused the existing type error. This PR adds overloads for map_infer_mask (convert=False returns np.ndarray); because the call still looks invalid, mypy picks the last overload, which is not correct here, so I cast the argument to np.dtype to avoid this mess.
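
For illustration, a minimal sketch of the cast pattern used above (not pandas code; the simplified union annotation is an assumption). The cast is a no-op at runtime but gives mypy a concrete argument type for np.dtype(), so no ignore comment is needed:

from typing import Union, cast

import numpy as np


def build_dtype(dtype: Union[str, np.dtype, type]) -> np.dtype:
    # np.dtype() accepts every member of this union at runtime, but mypy
    # cannot pick a generic overload for the union as a whole, so the
    # argument is cast to a plain `type` before the call.
    return np.dtype(cast(type, dtype))


assert build_dtype(object) == np.dtype(object)
assert build_dtype("float64") == np.dtype("float64")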

)

if not na_value_is_na:
5 changes: 3 additions & 2 deletions pandas/core/arrays/string_arrow.py
@@ -7,6 +7,7 @@
TYPE_CHECKING,
Callable,
Union,
cast,
)
import warnings

@@ -327,7 +328,7 @@ def _str_map(
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(dtype), # type: ignore[arg-type]
dtype=np.dtype(cast(type, dtype)),
)

if not na_value_is_na:
@@ -640,7 +641,7 @@ def _str_map(
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(dtype), # type: ignore[arg-type]
dtype=np.dtype(cast(type, dtype)),
)
return result

7 changes: 5 additions & 2 deletions pandas/core/computation/eval.py
@@ -4,7 +4,10 @@
from __future__ import annotations

import tokenize
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
)
import warnings

from pandas.util._exceptions import find_stack_level
@@ -177,7 +180,7 @@ def eval(
level: int = 0,
target=None,
inplace: bool = False,
):
) -> Any:
"""
Evaluate a Python expression as a string using various backends.

5 changes: 4 additions & 1 deletion pandas/core/dtypes/missing.py
@@ -258,7 +258,9 @@ def _use_inf_as_na(key) -> None:
globals()["INF_AS_NA"] = False


def _isna_array(values: ArrayLike, inf_as_na: bool = False):
def _isna_array(
values: ArrayLike, inf_as_na: bool = False
) -> npt.NDArray[np.bool_] | NDFrame:
"""
Return an array indicating which values of the input array are NaN / NA.

@@ -275,6 +277,7 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
Array of boolean values denoting the NA status of each element.
"""
dtype = values.dtype
result: npt.NDArray[np.bool_] | NDFrame

if not isinstance(values, np.ndarray):
# i.e. ExtensionArray
4 changes: 3 additions & 1 deletion pandas/core/frame.py
@@ -9802,7 +9802,9 @@ def explode(

return result.__finalize__(self, method="explode")

def unstack(self, level: IndexLabel = -1, fill_value=None, sort: bool = True):
def unstack(
self, level: IndexLabel = -1, fill_value=None, sort: bool = True
) -> DataFrame | Series:

Member Author commented:
unstack returns a Series when a DataFrame has a non-MultiIndex index (this is also documented; a Series with a non-MultiIndex index raises instead). Pyright was able to infer that DataFrame | Series can be returned; unfortunately, this requires quite a few ignore codes.
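
A quick illustration of the documented behaviour referenced above (assumed example data, not part of this diff):

import pandas as pd

flat = pd.DataFrame({"a": [1, 2]}, index=["x", "y"])
print(type(flat.unstack()))   # <class 'pandas.core.series.Series'>

multi = pd.DataFrame(
    {"a": [1, 2, 3, 4]},
    index=pd.MultiIndex.from_product([["x", "y"], [1, 2]]),
)
print(type(multi.unstack()))  # <class 'pandas.core.frame.DataFrame'>

# Per the comment above, Series.unstack with a non-MultiIndex index raises.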

"""
Pivot a level of the (necessarily hierarchical) index labels.

6 changes: 3 additions & 3 deletions pandas/core/indexes/range.py
@@ -491,7 +491,7 @@ def copy(self, name: Hashable | None = None, deep: bool = False) -> Self:
new_index = self._rename(name=name)
return new_index

def _minmax(self, meth: str):
def _minmax(self, meth: str) -> int | float:
no_steps = len(self) - 1
if no_steps == -1:
return np.nan
@@ -500,13 +500,13 @@ def _minmax(self, meth: str):

return self.start + self.step * no_steps

def min(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
def min(self, axis=None, skipna: bool = True, *args, **kwargs) -> int | float:
"""The minimum value of the RangeIndex"""
nv.validate_minmax_axis(axis)
nv.validate_min(args, kwargs)
return self._minmax("min")

def max(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
def max(self, axis=None, skipna: bool = True, *args, **kwargs) -> int | float:
"""The maximum value of the RangeIndex"""
nv.validate_minmax_axis(axis)
nv.validate_max(args, kwargs)
6 changes: 4 additions & 2 deletions pandas/core/ops/common.py
@@ -40,7 +40,7 @@ def wrapper(method: F) -> F:
return wrapper


def _unpack_zerodim_and_defer(method, name: str):
def _unpack_zerodim_and_defer(method: F, name: str) -> F:
"""
Boilerplate for pandas conventions in arithmetic and comparison methods.

@@ -75,7 +75,9 @@ def new_method(self, other):

return method(self, other)

return new_method
# error: Incompatible return value type (got "Callable[[Any, Any], Any]",
# expected "F")
return new_method # type: ignore[return-value]


def get_op_result_name(left, right):
3 changes: 2 additions & 1 deletion pandas/core/reshape/pivot.py
@@ -568,7 +568,8 @@ def pivot(
# error: Argument 1 to "unstack" of "DataFrame" has incompatible type "Union
# [List[Any], ExtensionArray, ndarray[Any, Any], Index, Series]"; expected
# "Hashable"
result = indexed.unstack(columns_listlike) # type: ignore[arg-type]
# unstack with a MultiIndex returns a DataFrame
result = cast("DataFrame", indexed.unstack(columns_listlike)) # type: ignore[arg-type]
result.index.names = [
name if name is not lib.no_default else None for name in result.index.names
]