Skip to content

TYP: Add cast to ABC classes. #37902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 24, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from pandas.core.indexers import validate_indices

if TYPE_CHECKING:
from pandas import Categorical, DataFrame, Series
from pandas import Categorical, DataFrame, Index, Series

_shared_docs: Dict[str, str] = {}

Expand Down Expand Up @@ -533,7 +533,7 @@ def factorize(
sort: bool = False,
na_sentinel: Optional[int] = -1,
size_hint: Optional[int] = None,
) -> Tuple[np.ndarray, Union[np.ndarray, ABCIndex]]:
) -> Tuple[np.ndarray, Union[np.ndarray, "Index"]]:
"""
Encode the object as an enumerated type or categorical variable.

Expand Down
13 changes: 7 additions & 6 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,14 @@ def __getitem__(self, key):
return self._gotitem(list(key), ndim=2)

elif not getattr(self, "as_index", False):
if key not in self.obj.columns:
# error: "SelectionMixin" has no attribute "obj" [attr-defined]
if key not in self.obj.columns: # type: ignore[attr-defined]
raise KeyError(f"Column not found: {key}")
return self._gotitem(key, ndim=2)

else:
if key not in self.obj:
# error: "SelectionMixin" has no attribute "obj" [attr-defined]
if key not in self.obj: # type: ignore[attr-defined]
raise KeyError(f"Column not found: {key}")
return self._gotitem(key, ndim=1)

Expand Down Expand Up @@ -919,10 +921,9 @@ def _map_values(self, mapper, na_action=None):
# "astype" [attr-defined]
values = self.astype(object)._values # type: ignore[attr-defined]
if na_action == "ignore":

def map_f(values, f):
return lib.map_infer_mask(values, f, isna(values).view(np.uint8))

map_f = lambda values, f: lib.map_infer_mask(
values, f, isna(values).view(np.uint8)
)
elif na_action is None:
map_f = lib.map_infer
else:
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ def convert_to_list_like(
inputs are returned unmodified whereas others are converted to list.
"""
if isinstance(values, (list, np.ndarray, ABCIndex, ABCSeries, ABCExtensionArray)):
return values
# np.ndarray resolving as Any gives a false positive
return values # type: ignore[return-value]
elif isinstance(values, abc.Iterable) and not isinstance(values, str):
return list(values)

Expand Down
13 changes: 8 additions & 5 deletions pandas/core/computation/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from functools import partial, wraps
from typing import Dict, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Type, Union
import warnings

import numpy as np
Expand All @@ -17,13 +17,16 @@
import pandas.core.common as com
from pandas.core.computation.common import result_type_many

if TYPE_CHECKING:
from pandas.core.indexes.api import Index


def _align_core_single_unary_op(
term,
) -> Tuple[Union[partial, Type[FrameOrSeries]], Optional[Dict[str, int]]]:
) -> Tuple[Union[partial, Type[FrameOrSeries]], Optional[Dict[str, "Index"]]]:

typ: Union[partial, Type[FrameOrSeries]]
axes: Optional[Dict[str, int]] = None
axes: Optional[Dict[str, "Index"]] = None

if isinstance(term.value, np.ndarray):
typ = partial(np.asanyarray, dtype=term.value.dtype)
Expand All @@ -36,8 +39,8 @@ def _align_core_single_unary_op(


def _zip_axes_from_type(
typ: Type[FrameOrSeries], new_axes: Sequence[int]
) -> Dict[str, int]:
typ: Type[FrameOrSeries], new_axes: Sequence["Index"]
) -> Dict[str, "Index"]:
return {name: new_axes[i] for i, name in enumerate(typ._AXIS_ORDERS)}


Expand Down
8 changes: 5 additions & 3 deletions pandas/core/computation/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import tokenize
from typing import Iterator, Tuple

from pandas._typing import Label

# A token value Python's tokenizer probably will never use.
BACKTICK_QUOTED_STRING = 100

Expand Down Expand Up @@ -91,7 +93,7 @@ def clean_backtick_quoted_toks(tok: Tuple[int, str]) -> Tuple[int, str]:
return toknum, tokval


def clean_column_name(name: str) -> str:
def clean_column_name(name: "Label") -> "Label":
"""
Function to emulate the cleaning of a backtick quoted name.

Expand All @@ -102,12 +104,12 @@ def clean_column_name(name: str) -> str:

Parameters
----------
name : str
name : Label
Name to be cleaned.

Returns
-------
name : str
name : Label
Returns the name after tokenizing and cleaning.

Notes
Expand Down
16 changes: 9 additions & 7 deletions pandas/core/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,9 @@ def array(
return result


def extract_array(obj: AnyArrayLike, extract_numpy: bool = False) -> ArrayLike:
def extract_array(
obj: AnyArrayLike, extract_numpy: bool = False
) -> Union[ExtensionArray, np.ndarray]:
"""
Extract the ndarray or ExtensionArray from a Series or Index.

Expand Down Expand Up @@ -394,14 +396,14 @@ def extract_array(obj: AnyArrayLike, extract_numpy: bool = False) -> ArrayLike:
array([1, 2, 3])
"""
if isinstance(obj, (ABCIndexClass, ABCSeries)):
obj = obj.array
result = obj.array
else:
result = obj

if extract_numpy and isinstance(obj, ABCPandasArray):
obj = obj.to_numpy()
if extract_numpy and isinstance(result, ABCPandasArray):
result = result.to_numpy()

# error: Incompatible return value type (got "Index", expected "ExtensionArray")
# error: Incompatible return value type (got "Series", expected "ExtensionArray")
return obj # type: ignore[return-value]
return result
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is that AnyArrayLike is a TypeVar, and the type of obj is wrong anyway — it should be object, I think. We could also change obj to object/Any for now, until we overload, and keep the implementation unchanged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert the changes to the implementation?

Copy link
Member Author

@rhshadrach rhshadrach Nov 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but that would require adding `obj = obj.array  # type: ignore[assignment]`, or possibly changing the typing on the arguments (not the return type).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this is a similar issue to #37902 (comment) where we changed FrameOrSeries to FrameOrSeriesUnion.

We don't yet have an AnyArrayLikeUnion to do the same here.

Because AnyArrayLike is a TypeVar, the type annotations imply that the type doesn't change.

This is not true: the type is preserved only for np.ndarray and ExtensionArray inputs, and only if extract_numpy=False.

so maybe the signature should be

def extract_array(
    obj: Union[object, ArrayLike], extract_numpy: bool = False
) -> Union[object, ArrayLike]:

But we also need to overload (to account for Series and Index array extraction); however, while numpy types resolve to Any, we get mypy errors about overlapping overloads.

So for now, I think it is OK to do...

def extract_array(
    obj: object, extract_numpy: bool = False
) -> Union[Any, ArrayLike]:

It's the use of the TypeVar that causes the issue, so I'm not too concerned about which type annotations are chosen at this stage, so long as we are not changing the implementation because we are using a TypeVar incorrectly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! Changes made.



def sanitize_array(
Expand Down
19 changes: 16 additions & 3 deletions pandas/core/dtypes/generic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
""" define generic base classes for pandas objects """

from typing import TYPE_CHECKING, Type, cast

if TYPE_CHECKING:
from pandas import DataFrame, NDFrame, Series


# define abstract base classes to enable isinstance type checking on our
# objects
Expand Down Expand Up @@ -53,9 +58,17 @@ def _check(cls, inst) -> bool:
},
)

ABCNDFrame = create_pandas_abc_type("ABCNDFrame", "_typ", ("series", "dataframe"))
ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series",))
ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",))
ABCNDFrame = cast(
"Type[NDFrame]",
create_pandas_abc_type("ABCNDFrame", "_typ", ("series", "dataframe")),
)
ABCSeries = cast(
"Type[Series]",
create_pandas_abc_type("ABCSeries", "_typ", ("series",)),
)
ABCDataFrame = cast(
"Type[DataFrame]", create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",))
)

ABCCategorical = create_pandas_abc_type("ABCCategorical", "_typ", ("categorical"))
ABCDatetimeArray = create_pandas_abc_type("ABCDatetimeArray", "_typ", ("datetimearray"))
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _get_axis_resolvers(self, axis: str) -> Dict[str, Union[Series, MultiIndex]]
return d

@final
def _get_index_resolvers(self) -> Dict[str, Union[Series, MultiIndex]]:
def _get_index_resolvers(self) -> Dict[Label, Union[Series, MultiIndex]]:
from pandas.core.computation.parsing import clean_column_name

d: Dict[str, Union[Series, MultiIndex]] = {}
Expand All @@ -521,7 +521,7 @@ def _get_index_resolvers(self) -> Dict[str, Union[Series, MultiIndex]]:
return {clean_column_name(k): v for k, v in d.items() if not isinstance(k, int)}

@final
def _get_cleaned_column_resolvers(self) -> Dict[str, ABCSeries]:
def _get_cleaned_column_resolvers(self) -> Dict[Label, Series]:
"""
Return the special character free column resolvers of a dataframe.

Expand All @@ -532,7 +532,6 @@ def _get_cleaned_column_resolvers(self) -> Dict[str, ABCSeries]:
from pandas.core.computation.parsing import clean_column_name

if isinstance(self, ABCSeries):
self = cast("Series", self)
return {clean_column_name(self.name): self}

return {
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,7 +2016,7 @@ def ravel(i):

raise ValueError("Incompatible indexer with Series")

def _align_frame(self, indexer, df: ABCDataFrame):
def _align_frame(self, indexer, df: "DataFrame"):
is_frame = self.ndim == 2

if isinstance(indexer, tuple):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def extract_index(data) -> Index:
index = Index([])
elif len(data) > 0:
raw_lengths = []
indexes = []
indexes: List[Any] = []

have_raw_arrays = False
have_series = False
Expand Down
20 changes: 16 additions & 4 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,21 @@
"""

from collections import abc
from typing import TYPE_CHECKING, Iterable, List, Mapping, Type, Union, cast, overload
from typing import (
TYPE_CHECKING,
Iterable,
List,
Mapping,
Optional,
Type,
Union,
cast,
overload,
)

import numpy as np

from pandas._typing import FrameOrSeries, FrameOrSeriesUnion, Label
from pandas._typing import FrameOrSeriesUnion, Label

from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
Expand Down Expand Up @@ -295,7 +305,7 @@ class _Concatenator:

def __init__(
self,
objs: Union[Iterable[FrameOrSeries], Mapping[Label, FrameOrSeries]],
objs: Union[Iterable["NDFrame"], Mapping[Label, "NDFrame"]],
axis=0,
join: str = "outer",
keys=None,
Expand Down Expand Up @@ -366,7 +376,7 @@ def __init__(
# get the sample
# want the highest ndim that we have, and must be non-empty
# unless all objs are empty
sample = None
sample: Optional["NDFrame"] = None
if len(ndims) > 1:
max_ndim = max(ndims)
for obj in objs:
Expand Down Expand Up @@ -436,6 +446,8 @@ def __init__(
# to line up
if self._is_frame and axis == 1:
name = 0
# mypy needs to know sample is not an NDFrame
sample = cast("FrameOrSeriesUnion", sample)
obj = sample._constructor({name: obj})

self.objs.append(obj)
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,10 @@ def __init__(self, data):
array = data.array
self._array = array

self._index = self._name = None
if isinstance(data, ABCSeries):
self._index = data.index
self._name = data.name
else:
self._index = self._name = None

# ._values.categories works for both Series/Index
self._parent = data._values.categories if self._is_categorical else data
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/window/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common utility functions for rolling operations"""
from collections import defaultdict
from typing import cast
import warnings

import numpy as np
Expand Down Expand Up @@ -109,6 +110,9 @@ def dataframe_from_int_dict(data, frame_template):

# set the index and reorder
if arg2.columns.nlevels > 1:
# mypy needs to know columns is a MultiIndex, Index doesn't
# have levels attribute
arg2.columns = cast(MultiIndex, arg2.columns)
result.index = MultiIndex.from_product(
arg2.columns.levels + [result_index]
)
Expand Down
6 changes: 4 additions & 2 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import matplotlib.ticker as ticker
import numpy as np

from pandas._typing import FrameOrSeries
from pandas._typing import FrameOrSeriesUnion

from pandas.core.dtypes.common import is_list_like
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
Expand All @@ -30,7 +30,9 @@ def format_date_labels(ax: "Axes", rot):
fig.subplots_adjust(bottom=0.2)


def table(ax, data: FrameOrSeries, rowLabels=None, colLabels=None, **kwargs) -> "Table":
def table(
ax, data: FrameOrSeriesUnion, rowLabels=None, colLabels=None, **kwargs
) -> "Table":
if isinstance(data, ABCSeries):
data = data.to_frame()
elif isinstance(data, ABCDataFrame):
Expand Down