Skip to content

TYP: Update type hints for ExtensionArray and ExtensionDtype #39501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f2c52a4
small typing fixes
Dr-Irv Jan 23, 2021
d7ff8d3
fix ExtensionArray and ExtensionDtype
Dr-Irv Jan 23, 2021
49fa06e
merge with master
Dr-Irv Jan 31, 2021
03b2c4a
fixes for delete, isin, unique
Dr-Irv Jan 31, 2021
3e19958
fix import of Literal
Dr-Irv Jan 31, 2021
6861901
remove quotes on ExtensionDtype.construct_from_string
Dr-Irv Jan 31, 2021
9be6486
move numpy workaround to _typing.py
Dr-Irv Feb 1, 2021
260b367
remove numpy dummy
Dr-Irv Feb 2, 2021
6276725
remove extra line in _typing
Dr-Irv Feb 2, 2021
4dafaca
Merge remote-tracking branch 'upstream/master' into extensiontyping
Dr-Irv Feb 3, 2021
8b2cee2
import Literal
Dr-Irv Feb 3, 2021
3a7d839
Merge remote-tracking branch 'upstream/master' into extensiontyping
Dr-Irv Feb 14, 2021
a21bb60
merge with master
Dr-Irv Mar 8, 2021
8cd6b76
isort precommit fix
Dr-Irv Mar 8, 2021
e0e0131
fix interval.repeat() typing
Dr-Irv Mar 8, 2021
6a6a21f
overload for __getitem__ and use pattern with ExtensionArrayT as self…
Dr-Irv Mar 9, 2021
bf753e6
lose less ExtensionArrayT. Make registry private. consolidate overload
Dr-Irv Mar 10, 2021
c9795a5
remove ExtensionArray typing of self
Dr-Irv Mar 10, 2021
d452842
Merge remote-tracking branch 'upstream/master' into extensiontyping
Dr-Irv Mar 10, 2021
3c2c78b
merge with upstream/master
Dr-Irv Mar 12, 2021
548c198
make extension arrays work with new typing, fixing astype and to_numpy
Dr-Irv Mar 12, 2021
db8ed9b
fix Literal import
Dr-Irv Mar 12, 2021
f8191f8
fix logic in ensure_int_or_float
Dr-Irv Mar 12, 2021
575645f
fix conflict with master
Dr-Irv Mar 12, 2021
6f8fcb5
fix typing in groupby to_numpy call
Dr-Irv Mar 12, 2021
3ea2420
fix groupby again. Allow kwargs for extension to_numpy
Dr-Irv Mar 13, 2021
c83a628
Merge remote-tracking branch 'upstream/master' into extensiontyping
simonjayhawkins Mar 13, 2021
5bb24d4
fixes for merge with master
Dr-Irv Mar 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -49,6 +51,9 @@
from pandas.core.missing import get_fill_func
from pandas.core.sorting import nargminmax, nargsort

if TYPE_CHECKING:
from typing import Literal

_extension_array_shared_docs: Dict[str, str] = {}

ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
Expand Down Expand Up @@ -348,7 +353,7 @@ def __len__(self) -> int:
"""
raise AbstractMethodError(self)

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
"""
Iterate over elements of the array.
"""
Expand All @@ -358,7 +363,7 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def __contains__(self, item) -> bool:
def __contains__(self, item: Any) -> bool:
"""
Return for `item in self`.
"""
Expand Down Expand Up @@ -397,7 +402,7 @@ def to_numpy(
self,
dtype: Optional[Dtype] = None,
copy: bool = False,
na_value=lib.no_default,
na_value: Optional[Any] = lib.no_default,
) -> np.ndarray:
"""
Convert to a NumPy ndarray.
Expand Down Expand Up @@ -476,7 +481,7 @@ def nbytes(self) -> int:
# Additional Methods
# ------------------------------------------------------------------------

def astype(self, dtype, copy=True):
def astype(self, dtype: Dtype, copy: bool = True) -> np.ndarray:
"""
Cast to a NumPy array with 'dtype'.

Expand Down Expand Up @@ -556,8 +561,8 @@ def argsort(
ascending: bool = True,
kind: str = "quicksort",
na_position: str = "last",
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> np.ndarray:
"""
Return the indices that would sort this array.
Expand Down Expand Up @@ -645,7 +650,12 @@ def argmax(self, skipna: bool = True) -> int:
raise NotImplementedError
return nargminmax(self, "argmax")

def fillna(self, value=None, method=None, limit=None):
def fillna(
self,
value: Optional[Union[Any, ArrayLike]] = None,
method: Optional[Literal["backfill", "bfill", "ffill", "pad"]] = None,
limit: Optional[int] = None,
) -> ExtensionArray:
"""
Fill NA/NaN values using the specified method.

Expand Down Expand Up @@ -697,7 +707,7 @@ def fillna(self, value=None, method=None, limit=None):
new_values = self.copy()
return new_values

def dropna(self):
def dropna(self) -> ExtensionArray:
"""
Return ExtensionArray without NA values.

Expand Down Expand Up @@ -761,7 +771,7 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
b = empty
return self._concat_same_type([a, b])

def unique(self):
def unique(self) -> ExtensionArray:
"""
Compute the ExtensionArray of unique values.

Expand All @@ -772,7 +782,12 @@ def unique(self):
uniques = unique(self.astype(object))
return self._from_sequence(uniques, dtype=self.dtype)

def searchsorted(self, value, side="left", sorter=None):
def searchsorted(
self,
value: ArrayLike,
side: Optional[Literal["left", "right"]] = "left",
sorter: Optional[ArrayLike] = None,
) -> np.ndarray:
"""
Find indices where elements should be inserted to maintain order.

Expand Down Expand Up @@ -853,7 +868,7 @@ def equals(self, other: object) -> bool:
equal_na = self.isna() & other.isna()
return bool((equal_values | equal_na).all())

def isin(self, values) -> np.ndarray:
def isin(self, values: Union[ExtensionArray, Sequence[Any]]) -> np.ndarray:
"""
Pointwise comparison for set containment in the given values.

Expand Down Expand Up @@ -987,7 +1002,9 @@ def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:

@Substitution(klass="ExtensionArray")
@Appender(_extension_array_shared_docs["repeat"])
def repeat(self, repeats, axis=None):
def repeat(
self, repeats: Union[int, Sequence[int]], axis: Literal[None] = None
) -> ExtensionArray:
nv.validate_repeat((), {"axis": axis})
ind = np.arange(len(self)).repeat(repeats)
return self.take(ind)
Expand Down Expand Up @@ -1171,7 +1188,7 @@ def _formatter(self, boxed: bool = False) -> Callable[[Any], Optional[str]]:
# Reshaping
# ------------------------------------------------------------------------

def transpose(self, *axes) -> ExtensionArray:
def transpose(self, *axes: int) -> ExtensionArray:
"""
Return a transposed view on this array.

Expand All @@ -1184,7 +1201,9 @@ def transpose(self, *axes) -> ExtensionArray:
def T(self) -> ExtensionArray:
return self.transpose()

def ravel(self, order="C") -> ExtensionArray:
def ravel(
self, order: Optional[Literal["C", "F", "A", "K"]] = "C"
) -> ExtensionArray:
"""
Return a flattened view on this array.

Expand Down Expand Up @@ -1258,13 +1277,15 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
"""
raise TypeError(f"cannot perform {name} with type {self.dtype}")

def __hash__(self):
def __hash__(self) -> int:
raise TypeError(f"unhashable type: {repr(type(self).__name__)}")

# ------------------------------------------------------------------------
# Non-Optimized Default Methods

def delete(self: ExtensionArrayT, loc) -> ExtensionArrayT:
def delete(
self: ExtensionArrayT, loc: Union[int, Sequence[int]]
) -> ExtensionArrayT:
indexer = np.delete(np.arange(len(self)), loc)
return self.take(indexer)

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import abc
import numbers
import operator
from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union, cast
import warnings

import numpy as np
Expand Down Expand Up @@ -1174,7 +1174,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
if skipna:
arr = self
else:
arr = self.dropna()
arr = cast(SparseArray, self.dropna())

# we don't support these kwargs.
# They should only be present when called via pandas, so do it here.
Expand Down
12 changes: 5 additions & 7 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, cast

import numpy as np

Expand Down Expand Up @@ -199,7 +199,7 @@ def construct_array_type(cls) -> Type[ExtensionArray]:
raise NotImplementedError

@classmethod
def construct_from_string(cls, string: str):
def construct_from_string(cls, string: str) -> ExtensionDtype:
r"""
Construct this type from a string.

Expand Down Expand Up @@ -410,9 +410,7 @@ def register(self, dtype: Type[ExtensionDtype]) -> None:

self.dtypes.append(dtype)

def find(
self, dtype: Union[Type[ExtensionDtype], str]
) -> Optional[Type[ExtensionDtype]]:
def find(self, dtype: Union[Type[ExtensionDtype], str]) -> Optional[ExtensionDtype]:
"""
Parameters
----------
Expand All @@ -427,7 +425,7 @@ def find(
if not isinstance(dtype, type):
dtype_type = type(dtype)
if issubclass(dtype_type, ExtensionDtype):
return dtype
return cast(ExtensionDtype, dtype)

return None

Expand All @@ -440,4 +438,4 @@ def find(
return None


registry = Registry()
registry: Registry = Registry()