Skip to content

TYP: simple return types #54786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 31, 2023
2 changes: 1 addition & 1 deletion pandas/_testing/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def round_trip_localpath(writer, reader, path: str | None = None):
return obj


def write_to_compressed(compression, path, data, dest: str = "test"):
def write_to_compressed(compression, path, data, dest: str = "test") -> None:
"""
Write data to a compressed file.

Expand Down
2 changes: 1 addition & 1 deletion pandas/arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str) -> type[NumpyExtensionArray]:
if name == "PandasArray":
# GH#53694
import warnings
Expand Down
2 changes: 1 addition & 1 deletion pandas/compat/pickle_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections.abc import Generator


def load_reduce(self):
def load_reduce(self) -> None:
stack = self.stack
args = stack.pop()
func = stack[-1]
Expand Down
4 changes: 2 additions & 2 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def box_with_array(request):


@pytest.fixture
def dict_subclass():
def dict_subclass() -> type[dict]:
"""
Fixture for a dictionary subclass.
"""
Expand All @@ -504,7 +504,7 @@ def __init__(self, *args, **kwargs) -> None:


@pytest.fixture
def non_dict_mapping_subclass():
def non_dict_mapping_subclass() -> type[abc.Mapping]:
"""
Fixture for a Mapping subclass that is not a dict.
"""
Expand Down
16 changes: 9 additions & 7 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@

if TYPE_CHECKING:
from collections.abc import (
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)

Expand Down Expand Up @@ -253,7 +253,7 @@ def transform(self) -> DataFrame | Series:

return result

def transform_dict_like(self, func):
def transform_dict_like(self, func) -> DataFrame:
"""
Compute transform in the case of a dict-like func
"""
Expand Down Expand Up @@ -330,7 +330,7 @@ def compute_list_like(

Returns
-------
keys : list[hashable]
keys : list[Hashable] or Index
Index labels for result.
results : list
Data for result. When aggregating with a Series, this can contain any
Expand Down Expand Up @@ -370,7 +370,9 @@ def compute_list_like(
new_res = getattr(colg, op_name)(func, *args, **kwargs)
results.append(new_res)
indices.append(index)
keys = selected_obj.columns.take(indices)
# error: Incompatible types in assignment (expression has type "Any |
# Index", variable has type "list[Any | Callable[..., Any] | str]")
keys = selected_obj.columns.take(indices) # type: ignore[assignment]

return keys, results

Expand Down Expand Up @@ -772,7 +774,7 @@ def result_columns(self) -> Index:

@property
@abc.abstractmethod
def series_generator(self) -> Iterator[Series]:
def series_generator(self) -> Generator[Series, None, None]:
pass

@abc.abstractmethod
Expand Down Expand Up @@ -1014,7 +1016,7 @@ class FrameRowApply(FrameApply):
axis: AxisInt = 0

@property
def series_generator(self):
def series_generator(self) -> Generator[Series, None, None]:
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))

@property
Expand Down Expand Up @@ -1075,7 +1077,7 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
return result.T

@property
def series_generator(self):
def series_generator(self) -> Generator[Series, None, None]:
values = self.values
values = ensure_wrapped_if_datetimelike(values)
assert len(values) > 0
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/arrow/extension_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __ne__(self, other) -> bool:
def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
def to_pandas_dtype(self) -> PeriodDtype:
return PeriodDtype(freq=self.freq)


Expand Down Expand Up @@ -105,7 +105,7 @@ def __ne__(self, other) -> bool:
def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.closed))

def to_pandas_dtype(self):
def to_pandas_dtype(self) -> IntervalDtype:
return IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)


Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,7 +2410,7 @@ def _mode(self, dropna: bool = True) -> Categorical:
# ------------------------------------------------------------------
# ExtensionArray Interface

def unique(self):
def unique(self) -> Self:
"""
Return the ``Categorical`` which ``categories`` and ``codes`` are
unique.
Expand Down
38 changes: 22 additions & 16 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
)


IntervalSideT = Union[TimeArrayLike, np.ndarray]
IntervalSide = Union[TimeArrayLike, np.ndarray]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...T is often used for TypeVars

IntervalOrNA = Union[Interval, float]

_interval_shared_docs: dict[str, str] = {}
Expand Down Expand Up @@ -219,8 +219,8 @@ def ndim(self) -> Literal[1]:
return 1

# To make mypy recognize the fields
_left: IntervalSideT
_right: IntervalSideT
_left: IntervalSide
_right: IntervalSide
_dtype: IntervalDtype

# ---------------------------------------------------------------------
Expand All @@ -237,8 +237,8 @@ def __new__(
data = extract_array(data, extract_numpy=True)

if isinstance(data, cls):
left: IntervalSideT = data._left
right: IntervalSideT = data._right
left: IntervalSide = data._left
right: IntervalSide = data._right
closed = closed or data.closed
dtype = IntervalDtype(left.dtype, closed=closed)
else:
Expand Down Expand Up @@ -280,8 +280,8 @@ def __new__(
@classmethod
def _simple_new(
cls,
left: IntervalSideT,
right: IntervalSideT,
left: IntervalSide,
right: IntervalSide,
dtype: IntervalDtype,
) -> Self:
result = IntervalMixin.__new__(cls)
Expand All @@ -299,7 +299,7 @@ def _ensure_simple_new_inputs(
closed: IntervalClosedType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
) -> tuple[IntervalSideT, IntervalSideT, IntervalDtype]:
) -> tuple[IntervalSide, IntervalSide, IntervalDtype]:
"""Ensure correctness of input parameters for cls._simple_new."""
from pandas.core.indexes.base import ensure_index

Expand Down Expand Up @@ -1038,8 +1038,8 @@ def _concat_same_type(cls, to_concat: Sequence[IntervalArray]) -> Self:
raise ValueError("Intervals must all be closed on the same side.")
closed = closed_set.pop()

left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
left: IntervalSide = np.concatenate([interval.left for interval in to_concat])
right: IntervalSide = np.concatenate([interval.right for interval in to_concat])

left, right, dtype = cls._ensure_simple_new_inputs(left, right, closed=closed)

Expand Down Expand Up @@ -1290,7 +1290,7 @@ def _format_space(self) -> str:
# Vectorized Interval Properties/Attributes

@property
def left(self):
def left(self) -> Index:
"""
Return the left endpoints of each Interval in the IntervalArray as an Index.

Expand All @@ -1310,7 +1310,7 @@ def left(self):
return Index(self._left, copy=False)

@property
def right(self):
def right(self) -> Index:
"""
Return the right endpoints of each Interval in the IntervalArray as an Index.

Expand Down Expand Up @@ -1862,11 +1862,17 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
return isin(self.astype(object), values.astype(object))

@property
def _combined(self) -> IntervalSideT:
left = self.left._values.reshape(-1, 1)
right = self.right._values.reshape(-1, 1)
def _combined(self) -> IntervalSide:
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "reshape" [union-attr]
left = self.left._values.reshape(-1, 1) # type: ignore[union-attr]
right = self.right._values.reshape(-1, 1) # type: ignore[union-attr]
if needs_i8_conversion(left.dtype):
comb = left._concat_same_type([left, right], axis=1)
# error: Item "ndarray[Any, Any]" of "Any | ndarray[Any, Any]" has
# no attribute "_concat_same_type"
comb = left._concat_same_type( # type: ignore[union-attr]
[left, right], axis=1
)
else:
comb = np.concatenate([left, right], axis=1)
return comb
Expand Down
9 changes: 3 additions & 6 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def _check_timedeltalike_freq_compat(self, other):
return lib.item_from_zerodim(delta)


def raise_on_incompatible(left, right):
def raise_on_incompatible(left, right) -> IncompatibleFrequency:
"""
Helper function to render a consistent error message when raising
IncompatibleFrequency.
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def validate_dtype_freq(dtype, freq: timedelta | str | None) -> BaseOffset:


def validate_dtype_freq(
dtype, freq: BaseOffsetT | timedelta | str | None
dtype, freq: BaseOffsetT | BaseOffset | timedelta | str | None
) -> BaseOffsetT:
"""
If both a dtype and a freq are available, ensure they match. If only
Expand All @@ -1117,10 +1117,7 @@ def validate_dtype_freq(
IncompatibleFrequency : mismatch between dtype and freq
"""
if freq is not None:
# error: Incompatible types in assignment (expression has type
# "BaseOffset", variable has type "Union[BaseOffsetT, timedelta,
# str, None]")
freq = to_offset(freq) # type: ignore[assignment]
freq = to_offset(freq)

if dtype is not None:
dtype = pandas_dtype(dtype)
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,9 @@ def npoints(self) -> int:
"""
return self.sp_index.npoints

def isna(self):
# error: Return type "SparseArray" of "isna" incompatible with return type
# "ndarray[Any, Any] | ExtensionArraySupportsAnyAll" in supertype "ExtensionArray"
def isna(self) -> Self: # type: ignore[override]
# If null fill value, we want SparseDtype[bool, true]
# to preserve the same memory usage.
dtype = SparseDtype(bool, self._null_fill_value)
Expand Down Expand Up @@ -1428,7 +1430,7 @@ def all(self, axis=None, *args, **kwargs):

return values.all()

def any(self, axis: AxisInt = 0, *args, **kwargs):
def any(self, axis: AxisInt = 0, *args, **kwargs) -> bool:
"""
Tests whether at least one of elements evaluate True

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NumpySorter,
NumpyValueArrayLike,
Scalar,
Self,
npt,
type_t,
)
Expand Down Expand Up @@ -131,7 +132,7 @@ def type(self) -> type[str]:
return str

@classmethod
def construct_from_string(cls, string):
def construct_from_string(cls, string) -> Self:
"""
Construct a StringDtype from a string.

Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
npt,
)

from pandas import Series


ArrowStringScalarOrNAT = Union[str, libmissing.NAType]

Expand Down Expand Up @@ -547,7 +549,7 @@ def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True):
def value_counts(self, dropna: bool = True) -> Series:
from pandas import Series

result = super().value_counts(dropna)
Expand Down
21 changes: 11 additions & 10 deletions pandas/core/computation/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tokenize
from typing import (
Callable,
ClassVar,
TypeVar,
)

Expand Down Expand Up @@ -349,8 +350,8 @@ class BaseExprVisitor(ast.NodeVisitor):
preparser : callable
"""

const_type: type[Term] = Constant
term_type = Term
const_type: ClassVar[type[Constant]] = Constant
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably have one PR just focused on making ClassVars more explicit (type checkers have difficulties determining whether they are class or instance variables).

term_type: ClassVar[type[Term]] = Term

binary_ops = CMP_OPS_SYMS + BOOL_OPS_SYMS + ARITH_OPS_SYMS
binary_op_nodes = (
Expand Down Expand Up @@ -540,26 +541,26 @@ def visit_UnaryOp(self, node, **kwargs):
operand = self.visit(node.operand)
return op(operand)

def visit_Name(self, node, **kwargs):
def visit_Name(self, node, **kwargs) -> Term:
return self.term_type(node.id, self.env, **kwargs)

# TODO(py314): deprecated since Python 3.8. Remove after Python 3.14 is min
def visit_NameConstant(self, node, **kwargs) -> Term:
def visit_NameConstant(self, node, **kwargs) -> Constant:
return self.const_type(node.value, self.env)

# TODO(py314): deprecated since Python 3.8. Remove after Python 3.14 is min
def visit_Num(self, node, **kwargs) -> Term:
def visit_Num(self, node, **kwargs) -> Constant:
return self.const_type(node.value, self.env)

def visit_Constant(self, node, **kwargs) -> Term:
def visit_Constant(self, node, **kwargs) -> Constant:
return self.const_type(node.value, self.env)

# TODO(py314): deprecated since Python 3.8. Remove after Python 3.14 is min
def visit_Str(self, node, **kwargs):
def visit_Str(self, node, **kwargs) -> Term:
name = self.env.add_tmp(node.s)
return self.term_type(name, self.env)

def visit_List(self, node, **kwargs):
def visit_List(self, node, **kwargs) -> Term:
name = self.env.add_tmp([self.visit(e)(self.env) for e in node.elts])
return self.term_type(name, self.env)

Expand All @@ -569,7 +570,7 @@ def visit_Index(self, node, **kwargs):
"""df.index[4]"""
return self.visit(node.value)

def visit_Subscript(self, node, **kwargs):
def visit_Subscript(self, node, **kwargs) -> Term:
from pandas import eval as pd_eval

value = self.visit(node.value)
Expand All @@ -589,7 +590,7 @@ def visit_Subscript(self, node, **kwargs):
name = self.env.add_tmp(v)
return self.term_type(name, env=self.env)

def visit_Slice(self, node, **kwargs):
def visit_Slice(self, node, **kwargs) -> slice:
"""df.index[slice(4,6)]"""
lower = node.lower
if lower is not None:
Expand Down
Loading