Skip to content

TYP: __getitem__ method of EA (2nd pass) #37921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import numpy as np

from pandas._libs.missing import NAType # noqa: F401

# To prevent import cycles place any internal imports in the branch below
# and use a string literal forward reference to it in subsequent types
# https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
Expand Down
17 changes: 15 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Optional, Sequence, Type, TypeVar, Union, overload

import numpy as np

Expand All @@ -25,6 +25,7 @@
NDArrayBackedExtensionArrayT = TypeVar(
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
)
EAScalarOrMissing = object # both scalar value and na_value can be any type


class NDArrayBackedExtensionArray(ExtensionArray):
Expand Down Expand Up @@ -214,9 +215,21 @@ def __setitem__(self, key, value):
def _validate_setitem_value(self, value):
return value

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible
# return types [misc]
def __getitem__(self, key: int) -> EAScalarOrMissing: # type: ignore[misc]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can also return NDArrayBackedExtensionArrayT

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate? Is this for 2D EAs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NDArrayBackedExtensionArray supports 2D, exactly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it also supports a tuple indexer? and I guess from that NDArrayBackedExtensionArray can't support nested data as it uses is_scalar checks.

So we would also need to change `EAScalarOrMissing = object  # both scalar value and na_value can be any type` to `ScalarOrScalarMissing = Scalar  # both values and na_value must be scalars`?

Copy link
Member

@jbrockmendel jbrockmendel Nov 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it also supports a tuple indexer?

Yes. For that matter, I think even 1D will support 1-tuples.

and I guess from that NDArrayBackedExtensionArray can't support nested data as it uses is_scalar checks.

PandasArray can have object dtype, and Categorical can hold tuples. We never see 2D versions of those in practice though (yet).

...

@overload
def __getitem__(
self: NDArrayBackedExtensionArrayT, key: Union[slice, np.ndarray]
) -> NDArrayBackedExtensionArrayT:
...

def __getitem__(
self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray]
) -> Union[NDArrayBackedExtensionArrayT, Any]:
) -> Union[NDArrayBackedExtensionArrayT, EAScalarOrMissing]:
if lib.is_integer(key):
# fast-path
result = self._ndarray[key]
Expand Down
14 changes: 13 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -51,6 +52,7 @@
_extension_array_shared_docs: Dict[str, str] = dict()

ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
EAScalarOrMissing = object # both scalar value and na_value can be any type


class ExtensionArray:
Expand Down Expand Up @@ -256,9 +258,19 @@ def _from_factorized(cls, values, original):
# Must be a Sequence
# ------------------------------------------------------------------------

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible
# return types [misc]
def __getitem__(self, item: int) -> EAScalarOrMissing: # type: ignore[misc]
...

@overload
def __getitem__(self, item: Union[slice, np.ndarray]) -> ExtensionArray:
...

def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[ExtensionArray, Any]:
) -> Union[ExtensionArray, EAScalarOrMissing]:
"""
Select a subset of self.

Expand Down
15 changes: 13 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypeVar,
Union,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -266,9 +267,19 @@ def __array__(self, dtype=None) -> np.ndarray:
return np.array(list(self), dtype=object)
return self._ndarray

@overload
def __getitem__(self, key: int) -> DTScalarOrNaT:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: this can still return DatetimeLikeArrayT.

...

@overload
def __getitem__(
self: DatetimeLikeArrayT, key: Union[slice, np.ndarray]
) -> DatetimeLikeArrayT:
...

def __getitem__(
self, key: Union[int, slice, np.ndarray]
) -> Union[DatetimeLikeArrayMixin, DTScalarOrNaT]:
self: DatetimeLikeArrayT, key: Union[int, slice, np.ndarray]
) -> Union[DatetimeLikeArrayT, DTScalarOrNaT]:
"""
This getitem defers to the underlying array, which by-definition can
only handle list-likes, slices, and integer scalars
Expand Down
8 changes: 3 additions & 5 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, time, timedelta, tzinfo
from typing import Optional, Union, cast
from typing import Optional, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -444,11 +444,9 @@ def _generate_range(
)

if not left_closed and len(index) and index[0] == start:
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[1:])
index = index[1:]
if not right_closed and len(index) and index[-1] == end:
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[:-1])
index = index[:-1]

dtype = tz_to_dtype(tz)
return cls._simple_new(index.asi8, freq=freq, dtype=dtype)
Expand Down
27 changes: 24 additions & 3 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

import numpy as np

from pandas._libs import lib, missing as libmissing
from pandas._typing import Scalar
from pandas._typing import NAType, Scalar
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly, doc

Expand All @@ -30,6 +39,8 @@


BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray")
# scalar value is a Python scalar, missing value is pd.NA
ScalarOrNAType = Union[Scalar, NAType]


class BaseMaskedDtype(ExtensionDtype):
Expand Down Expand Up @@ -102,9 +113,19 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
def dtype(self) -> BaseMaskedDtype:
raise AbstractMethodError(self)

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible return
# types [misc]
def __getitem__(self, item: int) -> ScalarOrNAType: # type: ignore[misc]
...

@overload
def __getitem__(self, item: Union[slice, np.ndarray]) -> BaseMaskedArray:
...

def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[BaseMaskedArray, Any]:
) -> Union[BaseMaskedArray, ScalarOrNAType]:
if is_integer(item):
if self._mask[item]:
return self.dtype.na_value
Expand Down
18 changes: 5 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3222,14 +3222,8 @@ def _get_nearest_indexer(self, target: "Index", limit, tolerance) -> np.ndarray:
right_indexer = self.get_indexer(target, "backfill", limit=limit)

target_values = target._values
# error: Unsupported left operand type for - ("ExtensionArray")
left_distances = np.abs(
self._values[left_indexer] - target_values # type: ignore[operator]
)
# error: Unsupported left operand type for - ("ExtensionArray")
right_distances = np.abs(
self._values[right_indexer] - target_values # type: ignore[operator]
)
left_distances = np.abs(self._values[left_indexer] - target_values)
right_distances = np.abs(self._values[right_indexer] - target_values)

op = operator.lt if self.is_monotonic_increasing else operator.le
indexer = np.where(
Expand All @@ -3248,8 +3242,7 @@ def _filter_indexer_tolerance(
indexer: np.ndarray,
tolerance,
) -> np.ndarray:
# error: Unsupported left operand type for - ("ExtensionArray")
distance = abs(self._values[indexer] - target) # type: ignore[operator]
distance = abs(self._values[indexer] - target)
indexer = np.where(distance <= tolerance, indexer, -1)
return indexer

Expand Down Expand Up @@ -4546,9 +4539,8 @@ def asof_locs(self, where: "Index", mask) -> np.ndarray:

result = np.arange(len(self))[mask].take(locs)

# TODO: overload return type of ExtensionArray.__getitem__
first_value = cast(Any, self._values[mask.argmax()])
result[(locs == 0) & (where._values < first_value)] = -1
first = mask.argmax()
result[(locs == 0) & (where._values < self._values[first])] = -1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverting the change from the previous PR.


return result

Expand Down
7 changes: 2 additions & 5 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any

import numpy as np

Expand Down Expand Up @@ -673,10 +673,7 @@ def difference(self, other, sort=None):

if self.equals(other):
# pass an empty PeriodArray with the appropriate dtype

# TODO: overload DatetimeLikeArrayMixin.__getitem__
values = cast(PeriodArray, self._data[:0])
return type(self)._simple_new(values, name=self.name)
return type(self)._simple_new(self._data[:0], name=self.name)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverting the change from the previous PR.


if is_object_dtype(other):
return self.astype(object).difference(other).astype(self.dtype)
Expand Down
1 change: 1 addition & 0 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def split_and_operate(
-------
list of blocks
"""
assert isinstance(self.values, np.ndarray)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to ensure that split_and_operate is not called from an EA block (or that the base method is not overridden).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the assert after the ndim check.

if mask is None:
mask = np.broadcast_to(True, shape=self.shape)

Expand Down