Skip to content

TYP: Use Self for type checking (remaining locations) #51524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Hashable,
Iterator,
Literal,
TypeVar,
cast,
final,
overload,
Expand All @@ -29,6 +28,7 @@
DtypeObj,
IndexLabel,
NDFrameT,
Self,
Shape,
npt,
)
Expand Down Expand Up @@ -91,8 +91,6 @@
"duplicated": "IndexOpsMixin",
}

_T = TypeVar("_T", bound="IndexOpsMixin")


class PandasObject(DirNamesMixin):
"""
Expand Down Expand Up @@ -285,7 +283,7 @@ def _values(self) -> ExtensionArray | np.ndarray:
raise AbstractMethodError(self)

@final
def transpose(self: _T, *args, **kwargs) -> _T:
def transpose(self, *args, **kwargs) -> Self:
"""
Return the transpose, which is by definition self.

Expand Down
5 changes: 2 additions & 3 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
if TYPE_CHECKING:
from pandas._typing import (
DtypeObj,
Self,
Shape,
npt,
type_t,
Expand Down Expand Up @@ -228,9 +229,7 @@ def empty(self, shape: Shape) -> type_t[ExtensionArray]:
return cls._empty(shape, dtype=self)

@classmethod
def construct_from_string(
cls: type_t[ExtensionDtypeT], string: str
) -> ExtensionDtypeT:
def construct_from_string(cls, string: str) -> Self:
r"""
Construct this type from a string.

Expand Down
8 changes: 2 additions & 6 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TYPE_CHECKING,
Hashable,
Sequence,
TypeVar,
cast,
final,
)
Expand Down Expand Up @@ -78,15 +77,14 @@
from pandas._typing import (
Axis,
AxisInt,
Self,
)

from pandas import (
DataFrame,
Series,
)

_LocationIndexerT = TypeVar("_LocationIndexerT", bound="_LocationIndexer")

# "null slice"
_NS = slice(None, None)
_one_ellipsis_message = "indexer may only contain one '...' entry"
Expand Down Expand Up @@ -669,9 +667,7 @@ class _LocationIndexer(NDFrameIndexerBase):
_takeable: bool

@final
def __call__(
self: _LocationIndexerT, axis: Axis | None = None
) -> _LocationIndexerT:
def __call__(self, axis: Axis | None = None) -> Self:
# we need to return a copy of ourselves
new_self = type(self)(self.name, self.obj)

Expand Down
67 changes: 33 additions & 34 deletions pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Callable,
Hashable,
Literal,
TypeVar,
)

import numpy as np
Expand Down Expand Up @@ -93,9 +92,9 @@
AxisInt,
DtypeObj,
QuantileInterpolation,
Self,
npt,
)
T = TypeVar("T", bound="BaseArrayManager")


class BaseArrayManager(DataManager):
Expand Down Expand Up @@ -131,7 +130,7 @@ def __init__(
) -> None:
raise NotImplementedError

def make_empty(self: T, axes=None) -> T:
def make_empty(self, axes=None) -> Self:
"""Return an empty ArrayManager with the items axis of len 0 (no columns)"""
if axes is None:
axes = [self.axes[1:], Index([])]
Expand Down Expand Up @@ -195,11 +194,11 @@ def __repr__(self) -> str:
return output

def apply(
self: T,
self,
f,
align_keys: list[str] | None = None,
**kwargs,
) -> T:
) -> Self:
"""
Iterate over the arrays, collect and create a new ArrayManager.

Expand Down Expand Up @@ -257,8 +256,8 @@ def apply(
return type(self)(result_arrays, new_axes) # type: ignore[arg-type]

def apply_with_block(
self: T, f, align_keys=None, swap_axis: bool = True, **kwargs
) -> T:
self, f, align_keys=None, swap_axis: bool = True, **kwargs
) -> Self:
# switch axis to follow BlockManager logic
if swap_axis and "axis" in kwargs and self.ndim == 2:
kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0
Expand Down Expand Up @@ -311,7 +310,7 @@ def apply_with_block(

return type(self)(result_arrays, self._axes)

def where(self: T, other, cond, align: bool) -> T:
def where(self, other, cond, align: bool) -> Self:
if align:
align_keys = ["other", "cond"]
else:
Expand All @@ -325,13 +324,13 @@ def where(self: T, other, cond, align: bool) -> T:
cond=cond,
)

def round(self: T, decimals: int, using_cow: bool = False) -> T:
def round(self, decimals: int, using_cow: bool = False) -> Self:
return self.apply_with_block("round", decimals=decimals, using_cow=using_cow)

def setitem(self: T, indexer, value) -> T:
def setitem(self, indexer, value) -> Self:
return self.apply_with_block("setitem", indexer=indexer, value=value)

def putmask(self: T, mask, new, align: bool = True) -> T:
def putmask(self, mask, new, align: bool = True) -> Self:
if align:
align_keys = ["new", "mask"]
else:
Expand All @@ -345,14 +344,14 @@ def putmask(self: T, mask, new, align: bool = True) -> T:
new=new,
)

def diff(self: T, n: int, axis: AxisInt) -> T:
def diff(self, n: int, axis: AxisInt) -> Self:
assert self.ndim == 2 and axis == 0 # caller ensures
return self.apply(algos.diff, n=n, axis=axis)

def interpolate(self: T, **kwargs) -> T:
def interpolate(self, **kwargs) -> Self:
return self.apply_with_block("interpolate", swap_axis=False, **kwargs)

def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
if fill_value is lib.no_default:
fill_value = None

Expand All @@ -364,7 +363,7 @@ def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
"shift", periods=periods, axis=axis, fill_value=fill_value
)

def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
def fillna(self, value, limit, inplace: bool, downcast) -> Self:
if limit is not None:
# Do this validation even if we go through one of the no-op paths
limit = libalgos.validate_limit(None, limit=limit)
Expand All @@ -373,13 +372,13 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
"fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
)

def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:
def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
if copy is None:
copy = True

return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)

def convert(self: T, copy: bool | None) -> T:
def convert(self, copy: bool | None) -> Self:
if copy is None:
copy = True

Expand All @@ -402,10 +401,10 @@ def _convert(arr):

return self.apply(_convert)

def replace_regex(self: T, **kwargs) -> T:
def replace_regex(self, **kwargs) -> Self:
return self.apply_with_block("_replace_regex", **kwargs)

def replace(self: T, to_replace, value, inplace: bool) -> T:
def replace(self, to_replace, value, inplace: bool) -> Self:
inplace = validate_bool_kwarg(inplace, "inplace")
assert np.ndim(value) == 0, value
# TODO "replace" is right now implemented on the blocks, we should move
Expand All @@ -415,12 +414,12 @@ def replace(self: T, to_replace, value, inplace: bool) -> T:
)

def replace_list(
self: T,
self,
src_list: list[Any],
dest_list: list[Any],
inplace: bool = False,
regex: bool = False,
) -> T:
) -> Self:
"""do a list replace"""
inplace = validate_bool_kwarg(inplace, "inplace")

Expand All @@ -432,7 +431,7 @@ def replace_list(
regex=regex,
)

def to_native_types(self: T, **kwargs) -> T:
def to_native_types(self, **kwargs) -> Self:
return self.apply(to_native_types, **kwargs)

@property
Expand All @@ -458,7 +457,7 @@ def is_view(self) -> bool:
def is_single_block(self) -> bool:
return len(self.arrays) == 1

def _get_data_subset(self: T, predicate: Callable) -> T:
def _get_data_subset(self, predicate: Callable) -> Self:
indices = [i for i, arr in enumerate(self.arrays) if predicate(arr)]
arrays = [self.arrays[i] for i in indices]
# TODO copy?
Expand All @@ -469,7 +468,7 @@ def _get_data_subset(self: T, predicate: Callable) -> T:
new_axes = [self._axes[0], new_cols]
return type(self)(arrays, new_axes, verify_integrity=False)

def get_bool_data(self: T, copy: bool = False) -> T:
def get_bool_data(self, copy: bool = False) -> Self:
"""
Select columns that are bool-dtype and object-dtype columns that are all-bool.

Expand All @@ -480,7 +479,7 @@ def get_bool_data(self: T, copy: bool = False) -> T:
"""
return self._get_data_subset(lambda x: x.dtype == np.dtype(bool))

def get_numeric_data(self: T, copy: bool = False) -> T:
def get_numeric_data(self, copy: bool = False) -> Self:
"""
Select columns that have a numeric dtype.

Expand All @@ -494,7 +493,7 @@ def get_numeric_data(self: T, copy: bool = False) -> T:
or getattr(arr.dtype, "_is_numeric", False)
)

def copy(self: T, deep: bool | Literal["all"] | None = True) -> T:
def copy(self, deep: bool | Literal["all"] | None = True) -> Self:
"""
Make deep or shallow copy of ArrayManager

Expand Down Expand Up @@ -531,7 +530,7 @@ def copy_func(ax):
return type(self)(new_arrays, new_axes, verify_integrity=False)

def reindex_indexer(
self: T,
self,
new_axis,
indexer,
axis: AxisInt,
Expand All @@ -542,7 +541,7 @@ def reindex_indexer(
only_slice: bool = False,
# ArrayManager specific keywords
use_na_proxy: bool = False,
) -> T:
) -> Self:
axis = self._normalize_axis(axis)
return self._reindex_indexer(
new_axis,
Expand All @@ -555,15 +554,15 @@ def reindex_indexer(
)

def _reindex_indexer(
self: T,
self,
new_axis,
indexer: npt.NDArray[np.intp] | None,
axis: AxisInt,
fill_value=None,
allow_dups: bool = False,
copy: bool | None = True,
use_na_proxy: bool = False,
) -> T:
) -> Self:
"""
Parameters
----------
Expand Down Expand Up @@ -634,11 +633,11 @@ def _reindex_indexer(
return type(self)(new_arrays, new_axes, verify_integrity=False)

def take(
self: T,
self,
indexer: npt.NDArray[np.intp],
axis: AxisInt = 1,
verify: bool = True,
) -> T:
) -> Self:
"""
Take items along any axis.
"""
Expand Down Expand Up @@ -926,7 +925,7 @@ def idelete(self, indexer) -> ArrayManager:
# --------------------------------------------------------------------
# Array-wise Operation

def grouped_reduce(self: T, func: Callable) -> T:
def grouped_reduce(self, func: Callable) -> Self:
"""
Apply grouped reduction function columnwise, returning a new ArrayManager.

Expand Down Expand Up @@ -965,7 +964,7 @@ def grouped_reduce(self: T, func: Callable) -> T:
# expected "List[Union[ndarray, ExtensionArray]]"
return type(self)(result_arrays, [index, columns]) # type: ignore[arg-type]

def reduce(self: T, func: Callable) -> T:
def reduce(self, func: Callable) -> Self:
"""
Apply reduction function column-wise, returning a single-row ArrayManager.

Expand Down
Loading