diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 62a3808d36ba2..0f2a90709fc24 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -2,6 +2,7 @@ Generic data algorithms. This module is experimental at the moment and not intended for public consumption """ +import abc import operator from textwrap import dedent from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union @@ -1065,7 +1066,7 @@ def _get_score(at): # --------------- # -class SelectN: +class SelectN(abc.ABC): def __init__(self, obj, n: int, keep: str): self.obj = obj self.n = n @@ -1090,6 +1091,9 @@ def is_valid_dtype_n_method(dtype) -> bool: is_numeric_dtype(dtype) and not is_complex_dtype(dtype) ) or needs_i8_conversion(dtype) + @abc.abstractmethod + def compute(self, method): ... + class SelectNSeries(SelectN): """ diff --git a/pandas/core/apply.py b/pandas/core/apply.py index a0351cb687d02..3818ca6ed5876 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -145,6 +145,9 @@ def get_result(self): """ compute the results """ # dispatch to agg if is_list_like(self.f) or is_dict_like(self.f): + if "axis" in self.kwds: + self.kwds.pop("axis") + return self.obj.aggregate(self.f, axis=self.axis, *self.args, **self.kwds) # all empty diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 6cb597ba75852..5cf04276b36fb 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -13,6 +13,7 @@ from pandas._libs import lib from pandas._typing import ArrayLike +from typing_extensions import Protocol from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -1053,6 +1054,19 @@ def __hash__(self): raise TypeError(f"unhashable type: {repr(type(self).__name__)}") +class OpsExtendable(Protocol): + + @classmethod + def _create_arithmetic_method(cls, op): ... + + @classmethod + def _create_comparison_method(cls, op): ... + + @classmethod + def _create_logical_method(cls, op): ... + + + class ExtensionOpsMixin: """ A base class for linking the operators to their dunder names. @@ -1065,41 +1079,41 @@ class ExtensionOpsMixin: """ @classmethod - def _add_arithmetic_ops(cls): - cls.__add__ = cls._create_arithmetic_method(operator.add) - cls.__radd__ = cls._create_arithmetic_method(ops.radd) - cls.__sub__ = cls._create_arithmetic_method(operator.sub) - cls.__rsub__ = cls._create_arithmetic_method(ops.rsub) - cls.__mul__ = cls._create_arithmetic_method(operator.mul) - cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) - cls.__pow__ = cls._create_arithmetic_method(operator.pow) - cls.__rpow__ = cls._create_arithmetic_method(ops.rpow) - cls.__mod__ = cls._create_arithmetic_method(operator.mod) - cls.__rmod__ = cls._create_arithmetic_method(ops.rmod) - cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv) - cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv) - cls.__truediv__ = cls._create_arithmetic_method(operator.truediv) - cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv) - cls.__divmod__ = cls._create_arithmetic_method(divmod) - cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod) + def _add_arithmetic_ops(cls: OpsExtendable): + setattr(cls, "__add__", cls._create_arithmetic_method(operator.add)) + setattr(cls, "__radd__", cls._create_arithmetic_method(ops.radd)) + setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub)) + setattr(cls, "__rsub__", cls._create_arithmetic_method(ops.rsub)) + setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul)) + setattr(cls, "__rmul__", cls._create_arithmetic_method(ops.rmul)) + setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow)) + setattr(cls, "__rpow__", cls._create_arithmetic_method(ops.rpow)) + setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod)) + setattr(cls, "__rmod__", cls._create_arithmetic_method(ops.rmod)) + setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv)) + setattr(cls, "__rfloordiv__", cls._create_arithmetic_method(ops.rfloordiv)) + setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv)) + setattr(cls, "__rtruediv__", cls._create_arithmetic_method(ops.rtruediv)) + setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod)) + setattr(cls, "__rdivmod__", cls._create_arithmetic_method(ops.rdivmod)) @classmethod - def _add_comparison_ops(cls): - cls.__eq__ = cls._create_comparison_method(operator.eq) - cls.__ne__ = cls._create_comparison_method(operator.ne) - cls.__lt__ = cls._create_comparison_method(operator.lt) - cls.__gt__ = cls._create_comparison_method(operator.gt) - cls.__le__ = cls._create_comparison_method(operator.le) - cls.__ge__ = cls._create_comparison_method(operator.ge) + def _add_comparison_ops(cls: OpsExtendable): + setattr(cls, "__eq__", cls._create_comparison_method(operator.eq)) + setattr(cls, "__ne__", cls._create_comparison_method(operator.ne)) + setattr(cls, "__lt__", cls._create_comparison_method(operator.lt)) + setattr(cls, "__gt__", cls._create_comparison_method(operator.gt)) + setattr(cls, "__le__", cls._create_comparison_method(operator.le)) + setattr(cls, "__ge__", cls._create_comparison_method(operator.ge)) @classmethod - def _add_logical_ops(cls): - cls.__and__ = cls._create_logical_method(operator.and_) - cls.__rand__ = cls._create_logical_method(ops.rand_) - cls.__or__ = cls._create_logical_method(operator.or_) - cls.__ror__ = cls._create_logical_method(ops.ror_) - cls.__xor__ = cls._create_logical_method(operator.xor) - cls.__rxor__ = cls._create_logical_method(ops.rxor) + def _add_logical_ops(cls: OpsExtendable): + setattr(cls, "__and__", cls._create_logical_method(operator.and_)) + setattr(cls, "__rand__", cls._create_logical_method(ops.rand_)) + setattr(cls, "__or__", cls._create_logical_method(operator.or_)) + setattr(cls, "__ror__", cls._create_logical_method(ops.ror_)) + setattr(cls, "__xor__", cls._create_logical_method(operator.xor)) + setattr(cls, "__rxor__", cls._create_logical_method(ops.rxor)) class ExtensionScalarOpsMixin(ExtensionOpsMixin): diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 4fabd8f558fee..7b09879af4673 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1,9 +1,11 @@ +import abc from datetime import datetime, timedelta import operator -from typing import Any, Sequence, Type, Union, cast +from typing import Any, Sequence, Tuple, Type, Union, cast import warnings import numpy as np +from typing_extensions import Protocol from pandas._libs import NaT, NaTType, Timestamp, algos, iNaT, lib from pandas._libs.tslibs.c_timestamp import integer_op_not_supported @@ -212,7 +214,12 @@ def _check_compatible_with( raise AbstractMethodError(self) -class DatelikeOps: +class DatelikeOperable(Protocol): + + def _format_native_types(self, date_format, na_rep): ... + + +class DatelikeOps(abc.ABC): """ Common ops for DatetimeIndex/PeriodIndex, but not TimedeltaIndex. """ @@ -221,7 +228,7 @@ class DatelikeOps: URL="https://docs.python.org/3/library/datetime.html" "#strftime-and-strptime-behavior" ) - def strftime(self, date_format): + def strftime(self: DatelikeOperable, date_format): """ Convert to Index using specified date_format. @@ -260,6 +267,30 @@ def strftime(self, date_format): return result.astype(object) +class TimelikeOperable(Protocol): + + @property + def tz(self): ... + + @property + def dtype(self): ... + + @property + def inferred_freq(self): ... + + def __len__(self): ... + + def _round(self, freq, mode, ambiguous, nonexistent): ... + + def tz_localize(self, tz, ambiguous="raise", nonexistent="raise"): ... + + def view(self, dtype): ... + + def _maybe_mask_results(self, result, fill_value=iNaT, convert=None): ... + + def _simple_new(self, values: np.ndarray, dtype=None, freq=None): ... + + class TimelikeOps: """ Common ops for TimedeltaIndex/DatetimeIndex, but not PeriodIndex. @@ -368,7 +399,7 @@ class TimelikeOps: dtype: datetime64[ns] """ - def _round(self, freq, mode, ambiguous, nonexistent): + def _round(self: TimelikeOperable, freq, mode, ambiguous, nonexistent): # round the local times if is_datetime64tz_dtype(self): # operate on naive timestamps, then convert back to aware @@ -385,18 +416,18 @@ def _round(self, freq, mode, ambiguous, nonexistent): return self._simple_new(result, dtype=self.dtype) @Appender((_round_doc + _round_example).format(op="round")) - def round(self, freq, ambiguous="raise", nonexistent="raise"): + def round(self: TimelikeOperable, freq, ambiguous="raise", nonexistent="raise"): return self._round(freq, RoundTo.NEAREST_HALF_EVEN, ambiguous, nonexistent) @Appender((_round_doc + _floor_example).format(op="floor")) - def floor(self, freq, ambiguous="raise", nonexistent="raise"): + def floor(self: TimelikeOperable, freq, ambiguous="raise", nonexistent="raise"): return self._round(freq, RoundTo.MINUS_INFTY, ambiguous, nonexistent) @Appender((_round_doc + _ceil_example).format(op="ceil")) - def ceil(self, freq, ambiguous="raise", nonexistent="raise"): + def ceil(self: TimelikeOperable, freq, ambiguous="raise", nonexistent="raise"): return self._round(freq, RoundTo.PLUS_INFTY, ambiguous, nonexistent) - def _with_freq(self, freq): + def _with_freq(self: TimelikeOperable, freq): """ Helper to set our freq in-place, returning self to allow method chaining. @@ -425,6 +456,73 @@ def _with_freq(self, freq): return self +class DatetimeLikeArrayProtocol(Protocol): + _data: Any + _freq: Any + _recognized_scalars: Any + _resolution: Any + _scalar_type: Any + freq: Any + dtype: Any + asi8: Any + ndim: int + + def __init__(self, values, dtype, freq=None, copy=False): ... + + def __add__(self, other): ... + + @property + def _box_func(self): ... + + @property + def size(self): ... + + @property + def freqstr(self) -> str: ... + + @property + def _isnan(self) -> bool: ... + + @property + def _hasnans(self) -> bool: ... + + def _generate_range(self, start, end, periods, freq, closed_or_fields): ... + + def shape(self) -> Tuple[int, ...]: ... + + def _simple_new(self, values: np.ndarray, dtype=None, freq=None): ... + + def _box_values(self, values): ... + + def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs): ... + + def _check_compatible_with(self, other, setitem: bool = False): ... + + def _unbox_scalar(self, value): ... + + def _validate_fill_value(self, fill_value): ... + + def _add_nat(self): ... + + def _add_offset(self, other): ... + + def _add_datetimelike_scalar(self, other): ... + + def _time_shift(self, other): ... + + def _add_datetime_arraylike(self, other): ... + + def _add_timedelta_arraylike(self, other): ... + + def _addsub_int_array(self, other, op): ... + + def _addsub_object_array(self, other, op): ... + + def copy(self): ... + + def isna(self) -> bool: ... + + class DatetimeLikeArrayMixin(ExtensionOpsMixin, AttributesMixin, ExtensionArray): """ Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray @@ -438,19 +536,19 @@ class DatetimeLikeArrayMixin(ExtensionOpsMixin, AttributesMixin, ExtensionArray) """ @property - def ndim(self) -> int: + def ndim(self: DatetimeLikeArrayProtocol) -> int: return self._data.ndim @property - def shape(self): + def shape(self: DatetimeLikeArrayProtocol): return self._data.shape - def reshape(self, *args, **kwargs): + def reshape(self: DatetimeLikeArrayProtocol, *args, **kwargs): # Note: we drop any freq data = self._data.reshape(*args, **kwargs) return type(self)(data, dtype=self.dtype) - def ravel(self, *args, **kwargs): + def ravel(self: DatetimeLikeArrayProtocol, *args, **kwargs): # Note: we drop any freq data = self._data.ravel(*args, **kwargs) return type(self)(data, dtype=self.dtype) @@ -468,11 +566,11 @@ def _box_values(self, values): """ return lib.map_infer(values, self._box_func) - def __iter__(self): + def __iter__(self: DatetimeLikeArrayProtocol): return (self._box_func(v) for v in self.asi8) @property - def asi8(self) -> np.ndarray: + def asi8(self: DatetimeLikeArrayProtocol) -> np.ndarray: """ Integer representation of the values. @@ -487,7 +585,7 @@ def asi8(self) -> np.ndarray: # ---------------------------------------------------------------- # Rendering Methods - def _format_native_types(self, na_rep="NaT", date_format=None): + def _format_native_types(self: DatetimeLikeArrayProtocol, na_rep="NaT", date_format=None): """ Helper method for astype when converting to strings. @@ -505,7 +603,7 @@ def _formatter(self, boxed=False): # Array-Like / EA-Interface Methods @property - def nbytes(self): + def nbytes(self: DatetimeLikeArrayProtocol): return self._data.nbytes def __array__(self, dtype=None) -> np.ndarray: @@ -515,14 +613,14 @@ def __array__(self, dtype=None) -> np.ndarray: return self._data @property - def size(self) -> int: + def size(self: DatetimeLikeArrayProtocol) -> int: """The number of elements in this array.""" return np.prod(self.shape) def __len__(self) -> int: return len(self._data) - def __getitem__(self, key): + def __getitem__(self: DatetimeLikeArrayProtocol, key): """ This getitem defers to the underlying array, which by-definition can only handle list-likes, slices, and integer scalars @@ -631,7 +729,7 @@ def _maybe_clear_freq(self): # DatetimeArray and TimedeltaArray pass - def astype(self, dtype, copy=True): + def astype(self: DatetimeLikeArrayProtocol, dtype, copy=True): # Some notes on cases we don't have to handle here in the base class: # 1. PeriodArray.astype handles period -> period # 2. DatetimeArray.astype handles conversion between tz. @@ -669,7 +767,7 @@ def astype(self, dtype, copy=True): else: return np.asarray(self, dtype=dtype) - def view(self, dtype=None): + def view(self: DatetimeLikeArrayProtocol, dtype=None): if dtype is None or dtype is self.dtype: return type(self)(self._data, dtype=self.dtype) return self._data.view(dtype=dtype) @@ -677,11 +775,11 @@ def view(self, dtype=None): # ------------------------------------------------------------------ # ExtensionArray Interface - def unique(self): + def unique(self: DatetimeLikeArrayProtocol): result = unique1d(self.asi8) return type(self)(result, dtype=self.dtype) - def _validate_fill_value(self, fill_value): + def _validate_fill_value(self: DatetimeLikeArrayProtocol, fill_value): """ If a fill_value is passed to `take` convert it to an i8 representation, raising ValueError if this is not possible. @@ -710,7 +808,7 @@ def _validate_fill_value(self, fill_value): ) return fill_value - def take(self, indices, allow_fill=False, fill_value=None): + def take(self: DatetimeLikeArrayProtocol, indices, allow_fill=False, fill_value=None): if allow_fill: fill_value = self._validate_fill_value(fill_value) @@ -748,22 +846,22 @@ def _concat_same_type(cls, to_concat): return cls._simple_new(values, dtype=dtype, freq=new_freq) - def copy(self): + def copy(self: DatetimeLikeArrayProtocol): values = self.asi8.copy() return type(self)._simple_new(values, dtype=self.dtype, freq=self.freq) - def _values_for_factorize(self): + def _values_for_factorize(self: DatetimeLikeArrayProtocol): return self.asi8, iNaT @classmethod - def _from_factorized(cls, values, original): + def _from_factorized(cls: DatetimeLikeArrayProtocol, values, original): return cls(values, dtype=original.dtype) - def _values_for_argsort(self): + def _values_for_argsort(self: DatetimeLikeArrayProtocol): return self._data @Appender(ExtensionArray.shift.__doc__) - def shift(self, periods=1, fill_value=None, axis=0): + def shift(self: DatetimeLikeArrayProtocol, periods=1, fill_value=None, axis=0): if not self.size or periods == 0: return self.copy() @@ -799,7 +897,7 @@ def shift(self, periods=1, fill_value=None, axis=0): # These are not part of the EA API, but we implement them because # pandas assumes they're there. - def searchsorted(self, value, side="left", sorter=None): + def searchsorted(self: DatetimeLikeArrayProtocol, value, side="left", sorter=None): """ Find indices where elements should be inserted to maintain order. @@ -859,7 +957,7 @@ def searchsorted(self, value, side="left", sorter=None): # TODO: Use datetime64 semantics for sorting, xref GH#29844 return self.asi8.searchsorted(value, side=side, sorter=sorter) - def repeat(self, repeats, *args, **kwargs): + def repeat(self: DatetimeLikeArrayProtocol, repeats, *args, **kwargs): """ Repeat elements of an array. @@ -871,7 +969,7 @@ def repeat(self, repeats, *args, **kwargs): values = self._data.repeat(repeats) return type(self)(values.view("i8"), dtype=self.dtype) - def value_counts(self, dropna=False): + def value_counts(self: DatetimeLikeArrayProtocol, dropna=False): """ Return a Series containing counts of unique values. @@ -912,24 +1010,24 @@ def map(self, mapper): # ------------------------------------------------------------------ # Null Handling - def isna(self): + def isna(self: DatetimeLikeArrayProtocol): return self._isnan @property # NB: override with cache_readonly in immutable subclasses - def _isnan(self): + def _isnan(self: DatetimeLikeArrayProtocol): """ return if each value is nan """ return self.asi8 == iNaT @property # NB: override with cache_readonly in immutable subclasses - def _hasnans(self): + def _hasnans(self: DatetimeLikeArrayProtocol): """ return if I have any nans; enables various perf speedups """ return bool(self._isnan.any()) - def _maybe_mask_results(self, result, fill_value=iNaT, convert=None): + def _maybe_mask_results(self: DatetimeLikeArrayProtocol, result, fill_value=iNaT, convert=None): """ Parameters ---------- @@ -954,7 +1052,7 @@ def _maybe_mask_results(self, result, fill_value=iNaT, convert=None): result[self._isnan] = fill_value return result - def fillna(self, value=None, method=None, limit=None): + def fillna(self: DatetimeLikeArrayProtocol, value=None, method=None, limit=None): # TODO(GH-20300): remove this # Just overriding to ensure that we avoid an astype(object). # Either 20300 or a `_values_for_fillna` would avoid this duplication. @@ -1005,14 +1103,14 @@ def fillna(self, value=None, method=None, limit=None): # Frequency Properties/Methods @property - def freq(self): + def freq(self: DatetimeLikeArrayProtocol): """ Return the frequency object if it is set, otherwise None. """ return self._freq @freq.setter - def freq(self, value): + def freq(self: DatetimeLikeArrayProtocol, value): if value is not None: value = frequencies.to_offset(value) self._validate_frequency(self, value) @@ -1020,7 +1118,7 @@ def freq(self, value): self._freq = value @property - def freqstr(self): + def freqstr(self: DatetimeLikeArrayProtocol): """ Return the frequency object as a string if its set, otherwise None. """ @@ -1029,7 +1127,7 @@ def freqstr(self): return self.freq.freqstr @property # NB: override with cache_readonly in immutable subclasses - def inferred_freq(self): + def inferred_freq(self: DatetimeLikeArrayProtocol): """ Tryies to return a string representing a frequency guess, generated by infer_freq. Returns None if it can't autodetect the @@ -1043,11 +1141,11 @@ def inferred_freq(self): return None @property # NB: override with cache_readonly in immutable subclasses - def _resolution(self): + def _resolution(self: DatetimeLikeArrayProtocol): return frequencies.Resolution.get_reso_from_freq(self.freqstr) @property # NB: override with cache_readonly in immutable subclasses - def resolution(self): + def resolution(self: DatetimeLikeArrayProtocol): """ Returns day, hour, minute, second, millisecond or microsecond """ @@ -1099,15 +1197,15 @@ def _validate_frequency(cls, index, freq, **kwargs): # see GH#23789 @property - def _is_monotonic_increasing(self): + def _is_monotonic_increasing(self: DatetimeLikeArrayProtocol): return algos.is_monotonic(self.asi8, timelike=True)[0] @property - def _is_monotonic_decreasing(self): + def _is_monotonic_decreasing(self: DatetimeLikeArrayProtocol): return algos.is_monotonic(self.asi8, timelike=True)[1] @property - def _is_unique(self): + def _is_unique(self: DatetimeLikeArrayProtocol): return len(unique1d(self.asi8)) == len(self) # ------------------------------------------------------------------ @@ -1149,7 +1247,7 @@ def _sub_period(self, other): def _add_offset(self, offset): raise AbstractMethodError(self) - def _add_timedeltalike_scalar(self, other): + def _add_timedeltalike_scalar(self: DatetimeLikeArrayProtocol, other): """ Add a delta of a timedeltalike @@ -1179,7 +1277,7 @@ def _add_timedeltalike_scalar(self, other): return type(self)(new_values, dtype=self.dtype, freq=new_freq) return type(self)(new_values, dtype=self.dtype)._with_freq("infer") - def _add_timedelta_arraylike(self, other): + def _add_timedelta_arraylike(self: DatetimeLikeArrayProtocol, other): """ Add a delta of a TimedeltaIndex @@ -1209,7 +1307,7 @@ def _add_timedelta_arraylike(self, other): return type(self)(new_values, dtype=self.dtype)._with_freq("infer") - def _add_nat(self): + def _add_nat(self: DatetimeLikeArrayProtocol): """ Add pd.NaT to self """ @@ -1224,7 +1322,7 @@ def _add_nat(self): result.fill(iNaT) return type(self)(result, dtype=self.dtype, freq=None) - def _sub_nat(self): + def _sub_nat(self: DatetimeLikeArrayProtocol): """ Subtract pd.NaT from self """ @@ -1238,7 +1336,7 @@ def _sub_nat(self): result.fill(iNaT) return result.view("timedelta64[ns]") - def _sub_period_array(self, other): + def _sub_period_array(self: DatetimeLikeArrayProtocol, other): """ Subtract a Period Array/Index from self. This is only valid if self is itself a Period Array/Index, raises otherwise. Both objects must @@ -1274,7 +1372,7 @@ def _sub_period_array(self, other): new_values[mask] = NaT return new_values - def _addsub_object_array(self, other: np.ndarray, op): + def _addsub_object_array(self: DatetimeLikeArrayProtocol, other: np.ndarray, op): """ Add or subtract array-like of DateOffset objects @@ -1305,7 +1403,7 @@ def _addsub_object_array(self, other: np.ndarray, op): result = extract_array(result, extract_numpy=True).reshape(self.shape) return result - def _time_shift(self, periods, freq=None): + def _time_shift(self: DatetimeLikeArrayProtocol, periods, freq=None): """ Shift each value by `periods`. @@ -1343,7 +1441,7 @@ def _time_shift(self, periods, freq=None): return self._generate_range(start=start, end=end, periods=None, freq=self.freq) @unpack_zerodim_and_defer("__add__") - def __add__(self, other): + def __add__(self: DatetimeLikeArrayProtocol, other): # scalar others if other is NaT: @@ -1390,12 +1488,12 @@ def __add__(self, other): return TimedeltaArray(result) return result - def __radd__(self, other): + def __radd__(self: DatetimeLikeArrayProtocol, other): # alias for __add__ return self.__add__(other) @unpack_zerodim_and_defer("__sub__") - def __sub__(self, other): + def __sub__(self: DatetimeLikeArrayProtocol, other): # scalar others if other is NaT: @@ -1444,7 +1542,7 @@ def __sub__(self, other): return TimedeltaArray(result) return result - def __rsub__(self, other): + def __rsub__(self: DatetimeLikeArrayProtocol, other): if is_datetime64_any_dtype(other) and is_timedelta64_dtype(self.dtype): # ndarray[datetime64] cannot be subtracted from self, so # we need to wrap in DatetimeArray/Index and flip the operation @@ -1480,7 +1578,7 @@ def __rsub__(self, other): return -(self - other) - def __iadd__(self, other): # type: ignore + def __iadd__(self: DatetimeLikeArrayProtocol, other): # type: ignore result = self + other self[:] = result[:] @@ -1489,7 +1587,7 @@ def __iadd__(self, other): # type: ignore self._freq = result._freq return self - def __isub__(self, other): # type: ignore + def __isub__(self: DatetimeLikeArrayProtocol, other): # type: ignore result = self - other self[:] = result[:] @@ -1501,14 +1599,14 @@ def __isub__(self, other): # type: ignore # -------------------------------------------------------------- # Reductions - def _reduce(self, name, axis=0, skipna=True, **kwargs): + def _reduce(self: DatetimeLikeArrayProtocol, name, axis=0, skipna=True, **kwargs): op = getattr(self, name, None) if op: return op(skipna=skipna, **kwargs) else: return super()._reduce(name, skipna, **kwargs) - def min(self, axis=None, skipna=True, *args, **kwargs): + def min(self: DatetimeLikeArrayProtocol, axis=None, skipna=True, *args, **kwargs): """ Return the minimum value of the Array or minimum along an axis. @@ -1528,7 +1626,7 @@ def min(self, axis=None, skipna=True, *args, **kwargs): return NaT return self._box_func(result) - def max(self, axis=None, skipna=True, *args, **kwargs): + def max(self: DatetimeLikeArrayProtocol, axis=None, skipna=True, *args, **kwargs): """ Return the maximum value of the Array or maximum along an axis. @@ -1560,7 +1658,7 @@ def max(self, axis=None, skipna=True, *args, **kwargs): # Don't have to worry about NA `result`, since no NA went in. return self._box_func(result) - def mean(self, skipna=True): + def mean(self: DatetimeLikeArrayProtocol, skipna=True): """ Return the mean value of the Array. diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index 363286704ba95..cbcc7668eb4ee 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -4,18 +4,41 @@ SeriesGroupBy and the DataFrameGroupBy objects. """ import collections +from typing import List + +from typing_extensions import Protocol from pandas.core.dtypes.common import is_list_like, is_scalar +from pandas._typing import FrameOrSeries + + OutputKey = collections.namedtuple("OutputKey", ["label", "position"]) +class Groupable(Protocol): + + # TODO: These probably shouldn't both be FrameOrSeries + def __init__(self, subset: FrameOrSeries, groupby: FrameOrSeries, parent: "Groupable", **kwargs): ... + + @property + def obj(self) -> FrameOrSeries: ... + + @property + def _attributes(self) -> List[str]: ... + + @property + def _groupby(self) -> FrameOrSeries: ... + + def _reset_cache(self) -> None: ... + + class GroupByMixin: """ Provide the groupby facilities to the mixed object. """ - def _gotitem(self, key, ndim, subset=None): + def _gotitem(self: Groupable, key, ndim, subset=None): """ Sub-classes to define. Return a sliced object. diff --git a/pandas/io/common.py b/pandas/io/common.py index 0fce8f5382686..9c6f729a3478a 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -497,9 +497,13 @@ def __init__( super().__init__(file, mode, zipfile.ZIP_DEFLATED, **kwargs) def write(self, data): - archive_name = self.filename if self.archive_name is not None: archive_name = self.archive_name + elif self.filename is not None: + archive_name = self.filename + else: + raise RuntimeError("No filename to write to!") + super().writestr(archive_name, data) @property diff --git a/pandas/io/formats/csvs.py b/pandas/io/formats/csvs.py index 091f7662630ff..2213ac056cc5e 100644 --- a/pandas/io/formats/csvs.py +++ b/pandas/io/formats/csvs.py @@ -237,6 +237,11 @@ def _save_header(self): if not (has_aliases or self.header): return if has_aliases: + # TODO: type checking here is tricky because header is annotated as + # a Union[bool, Sequence[Hashable]] but we actually accept "sequence" types + # that don't inherit from abc.Sequence (ndarray, ABCIndex) + assert not isinstance(header, bool) + if len(header) != len(cols): raise ValueError( f"Writing {len(cols)} cols but got {len(header)} aliases" @@ -268,12 +273,15 @@ def _save_header(self): # given a string for a DF with Index index_label = [index_label] - encoded_labels = list(index_label) + # TODO: mismatch here because encoded_labels is a Sequence[str] + # but we fill with Sequence[Hashable]; need to clean up handling + # of non-None / non-str contained objects + encoded_labels = list(index_label) # type: ignore else: encoded_labels = [] if not has_mi_columns or has_aliases: - encoded_labels += list(write_cols) + encoded_labels += list(write_cols) # type: ignore writer.writerow(encoded_labels) else: # write out the mi diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index 5aab5b814bae7..f623aad72598a 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -202,6 +202,7 @@ class TestPDApi(Base): "_testing", "_tslib", "_typing", + "_typing_extensions", "_version", ] diff --git a/pandas/util/_decorators.py b/pandas/util/_decorators.py index 71d02db10c7ba..485ba898c6bbf 100644 --- a/pandas/util/_decorators.py +++ b/pandas/util/_decorators.py @@ -293,6 +293,8 @@ def decorate(func): allow_args = allowed_args else: spec = inspect.getfullargspec(func) + + assert spec.defaults is not None # TODO: this might be a bug allow_args = spec.args[: -len(spec.defaults)] @wraps(func)