Skip to content

Commit 0200a9e

Browse files
twoertweinpmhatre1
authored andcommitted
TYP: Update pyright (pandas-dev#56892)
* TYP: Update pyright * isort * int | float
1 parent 667418e commit 0200a9e

File tree

19 files changed

+121
-45
lines changed

19 files changed

+121
-45
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ repos:
132132
types: [python]
133133
stages: [manual]
134134
additional_dependencies: &pyright_dependencies
135-
135+
136136
- id: pyright
137137
# note: assumes python env is setup and activated
138138
name: pyright reportGeneralTypeIssues

pandas/_config/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class DeprecatedOption(NamedTuple):
8888

8989
class RegisteredOption(NamedTuple):
9090
key: str
91-
defval: object
91+
defval: Any
9292
doc: str
9393
validator: Callable[[object], Any] | None
9494
cb: Callable[[str], Any] | None

pandas/_config/localization.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import platform
1111
import re
1212
import subprocess
13-
from typing import TYPE_CHECKING
13+
from typing import (
14+
TYPE_CHECKING,
15+
cast,
16+
)
1417

1518
from pandas._config.config import options
1619

@@ -152,7 +155,7 @@ def get_locales(
152155
out_locales = []
153156
for x in split_raw_locales:
154157
try:
155-
out_locales.append(str(x, encoding=options.display.encoding))
158+
out_locales.append(str(x, encoding=cast(str, options.display.encoding)))
156159
except UnicodeError:
157160
# 'locale -a' is used to populated 'raw_locales' and on
158161
# Redhat 7 Linux (and maybe others) prints locale names

pandas/_libs/lib.pyi

+25-3
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,26 @@ def fast_multiget(
6969
mapping: dict,
7070
keys: np.ndarray, # object[:]
7171
default=...,
72-
) -> np.ndarray: ...
72+
) -> ArrayLike: ...
7373
def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ...
7474
def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ...
75+
@overload
7576
def map_infer(
7677
arr: np.ndarray,
7778
f: Callable[[Any], Any],
78-
convert: bool = ...,
79+
*,
80+
convert: Literal[False],
7981
ignore_na: bool = ...,
8082
) -> np.ndarray: ...
8183
@overload
84+
def map_infer(
85+
arr: np.ndarray,
86+
f: Callable[[Any], Any],
87+
*,
88+
convert: bool = ...,
89+
ignore_na: bool = ...,
90+
) -> ArrayLike: ...
91+
@overload
8292
def maybe_convert_objects(
8393
objects: npt.NDArray[np.object_],
8494
*,
@@ -164,14 +174,26 @@ def is_all_arraylike(obj: list) -> bool: ...
164174
# Functions which in reality take memoryviews
165175

166176
def memory_usage_of_objects(arr: np.ndarray) -> int: ... # object[:] # np.int64
177+
@overload
167178
def map_infer_mask(
168179
arr: np.ndarray,
169180
f: Callable[[Any], Any],
170181
mask: np.ndarray, # const uint8_t[:]
171-
convert: bool = ...,
182+
*,
183+
convert: Literal[False],
172184
na_value: Any = ...,
173185
dtype: np.dtype = ...,
174186
) -> np.ndarray: ...
187+
@overload
188+
def map_infer_mask(
189+
arr: np.ndarray,
190+
f: Callable[[Any], Any],
191+
mask: np.ndarray, # const uint8_t[:]
192+
*,
193+
convert: bool = ...,
194+
na_value: Any = ...,
195+
dtype: np.dtype = ...,
196+
) -> ArrayLike: ...
175197
def indices_fast(
176198
index: npt.NDArray[np.intp],
177199
labels: np.ndarray, # const int64_t[:]

pandas/_libs/lib.pyx

+7-6
Original file line numberDiff line numberDiff line change
@@ -2864,10 +2864,11 @@ def map_infer_mask(
28642864
ndarray[object] arr,
28652865
object f,
28662866
const uint8_t[:] mask,
2867+
*,
28672868
bint convert=True,
28682869
object na_value=no_default,
28692870
cnp.dtype dtype=np.dtype(object)
2870-
) -> np.ndarray:
2871+
) -> "ArrayLike":
28712872
"""
28722873
Substitute for np.vectorize with pandas-friendly dtype inference.
28732874

@@ -2887,7 +2888,7 @@ def map_infer_mask(
28872888

28882889
Returns
28892890
-------
2890-
np.ndarray
2891+
np.ndarray or an ExtensionArray
28912892
"""
28922893
cdef Py_ssize_t n = len(arr)
28932894
result = np.empty(n, dtype=dtype)
@@ -2941,8 +2942,8 @@ def _map_infer_mask(
29412942
@cython.boundscheck(False)
29422943
@cython.wraparound(False)
29432944
def map_infer(
2944-
ndarray arr, object f, bint convert=True, bint ignore_na=False
2945-
) -> np.ndarray:
2945+
ndarray arr, object f, *, bint convert=True, bint ignore_na=False
2946+
) -> "ArrayLike":
29462947
"""
29472948
Substitute for np.vectorize with pandas-friendly dtype inference.
29482949

@@ -2956,7 +2957,7 @@ def map_infer(
29562957

29572958
Returns
29582959
-------
2959-
np.ndarray
2960+
np.ndarray or an ExtensionArray
29602961
"""
29612962
cdef:
29622963
Py_ssize_t i, n
@@ -3091,7 +3092,7 @@ def to_object_array_tuples(rows: object) -> np.ndarray:
30913092

30923093
@cython.wraparound(False)
30933094
@cython.boundscheck(False)
3094-
def fast_multiget(dict mapping, object[:] keys, default=np.nan) -> np.ndarray:
3095+
def fast_multiget(dict mapping, object[:] keys, default=np.nan) -> "ArrayLike":
30953096
cdef:
30963097
Py_ssize_t i, n = len(keys)
30973098
object val

pandas/core/arrays/string_.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
TYPE_CHECKING,
55
ClassVar,
66
Literal,
7+
cast,
78
)
89

910
import numpy as np
@@ -637,7 +638,7 @@ def _str_map(
637638
# error: Argument 1 to "dtype" has incompatible type
638639
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
639640
# "Type[object]"
640-
dtype=np.dtype(dtype), # type: ignore[arg-type]
641+
dtype=np.dtype(cast(type, dtype)),
641642
)
642643

643644
if not na_value_is_na:

pandas/core/arrays/string_arrow.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
TYPE_CHECKING,
88
Callable,
99
Union,
10+
cast,
1011
)
1112
import warnings
1213

@@ -327,7 +328,7 @@ def _str_map(
327328
# error: Argument 1 to "dtype" has incompatible type
328329
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
329330
# "Type[object]"
330-
dtype=np.dtype(dtype), # type: ignore[arg-type]
331+
dtype=np.dtype(cast(type, dtype)),
331332
)
332333

333334
if not na_value_is_na:
@@ -640,7 +641,7 @@ def _str_map(
640641
mask.view("uint8"),
641642
convert=False,
642643
na_value=na_value,
643-
dtype=np.dtype(dtype), # type: ignore[arg-type]
644+
dtype=np.dtype(cast(type, dtype)),
644645
)
645646
return result
646647

pandas/core/computation/eval.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from __future__ import annotations
55

66
import tokenize
7-
from typing import TYPE_CHECKING
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
)
811
import warnings
912

1013
from pandas.util._exceptions import find_stack_level
@@ -177,7 +180,7 @@ def eval(
177180
level: int = 0,
178181
target=None,
179182
inplace: bool = False,
180-
):
183+
) -> Any:
181184
"""
182185
Evaluate a Python expression as a string using various backends.
183186

pandas/core/dtypes/missing.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ def _use_inf_as_na(key) -> None:
258258
globals()["INF_AS_NA"] = False
259259

260260

261-
def _isna_array(values: ArrayLike, inf_as_na: bool = False):
261+
def _isna_array(
262+
values: ArrayLike, inf_as_na: bool = False
263+
) -> npt.NDArray[np.bool_] | NDFrame:
262264
"""
263265
Return an array indicating which values of the input array are NaN / NA.
264266
@@ -275,6 +277,7 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
275277
Array of boolean values denoting the NA status of each element.
276278
"""
277279
dtype = values.dtype
280+
result: npt.NDArray[np.bool_] | NDFrame
278281

279282
if not isinstance(values, np.ndarray):
280283
# i.e. ExtensionArray

pandas/core/frame.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9802,7 +9802,9 @@ def explode(
98029802

98039803
return result.__finalize__(self, method="explode")
98049804

9805-
def unstack(self, level: IndexLabel = -1, fill_value=None, sort: bool = True):
9805+
def unstack(
9806+
self, level: IndexLabel = -1, fill_value=None, sort: bool = True
9807+
) -> DataFrame | Series:
98069808
"""
98079809
Pivot a level of the (necessarily hierarchical) index labels.
98089810

pandas/core/indexes/range.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def copy(self, name: Hashable | None = None, deep: bool = False) -> Self:
491491
new_index = self._rename(name=name)
492492
return new_index
493493

494-
def _minmax(self, meth: str):
494+
def _minmax(self, meth: str) -> int | float:
495495
no_steps = len(self) - 1
496496
if no_steps == -1:
497497
return np.nan
@@ -500,13 +500,13 @@ def _minmax(self, meth: str):
500500

501501
return self.start + self.step * no_steps
502502

503-
def min(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
503+
def min(self, axis=None, skipna: bool = True, *args, **kwargs) -> int | float:
504504
"""The minimum value of the RangeIndex"""
505505
nv.validate_minmax_axis(axis)
506506
nv.validate_min(args, kwargs)
507507
return self._minmax("min")
508508

509-
def max(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
509+
def max(self, axis=None, skipna: bool = True, *args, **kwargs) -> int | float:
510510
"""The maximum value of the RangeIndex"""
511511
nv.validate_minmax_axis(axis)
512512
nv.validate_max(args, kwargs)

pandas/core/ops/common.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def wrapper(method: F) -> F:
4040
return wrapper
4141

4242

43-
def _unpack_zerodim_and_defer(method, name: str):
43+
def _unpack_zerodim_and_defer(method: F, name: str) -> F:
4444
"""
4545
Boilerplate for pandas conventions in arithmetic and comparison methods.
4646
@@ -75,7 +75,9 @@ def new_method(self, other):
7575

7676
return method(self, other)
7777

78-
return new_method
78+
# error: Incompatible return value type (got "Callable[[Any, Any], Any]",
79+
# expected "F")
80+
return new_method # type: ignore[return-value]
7981

8082

8183
def get_op_result_name(left, right):

pandas/core/reshape/pivot.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,8 @@ def pivot(
568568
# error: Argument 1 to "unstack" of "DataFrame" has incompatible type "Union
569569
# [List[Any], ExtensionArray, ndarray[Any, Any], Index, Series]"; expected
570570
# "Hashable"
571-
result = indexed.unstack(columns_listlike) # type: ignore[arg-type]
571+
# unstack with a MultiIndex returns a DataFrame
572+
result = cast("DataFrame", indexed.unstack(columns_listlike)) # type: ignore[arg-type]
572573
result.index.names = [
573574
name if name is not lib.no_default else None for name in result.index.names
574575
]

0 commit comments

Comments
 (0)