Skip to content

REF: share more EA methods #36154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly
from pandas.util._decorators import cache_readonly, doc

from pandas.core.algorithms import take, unique
from pandas.core.algorithms import searchsorted, take, unique
from pandas.core.array_algos.transforms import shift
from pandas.core.arrays.base import ExtensionArray

_T = TypeVar("_T", bound="NDArrayBackedExtensionArray")
Expand Down Expand Up @@ -120,3 +121,31 @@ def repeat(self: _T, repeats, axis=None) -> _T:
def unique(self: _T) -> _T:
    """Return the unique values of this array, rewrapped as the same type."""
    # De-duplicate on the backing ndarray, then rebuild an instance of
    # the concrete subclass around the result.
    deduped = unique(self._ndarray)
    return self._from_backing_data(deduped)

@classmethod
@doc(ExtensionArray._concat_same_type)
def _concat_same_type(cls, to_concat, axis: int = 0):
    # Compare dtypes by their string form: some dtypes (e.g. tz-aware
    # ones backed by tzlocal) are not hashable, so a set of dtype
    # objects would not work here.
    unique_dtypes = {str(arr.dtype) for arr in to_concat}
    if len(unique_dtypes) != 1:
        raise ValueError("to_concat must have the same dtype (tz)", unique_dtypes)

    # Concatenate the raw ndarrays, then rewrap using the first input
    # as the template for the resulting array type.
    backing = np.concatenate([arr._ndarray for arr in to_concat], axis=axis)
    return to_concat[0]._from_backing_data(backing)

@doc(ExtensionArray.searchsorted)
def searchsorted(self, value, side="left", sorter=None):
    # Delegate to the shared algos-level implementation, operating on
    # the backing ndarray.
    backing = self._ndarray
    return searchsorted(backing, value, side=side, sorter=sorter)

@doc(ExtensionArray.shift)
def shift(self, periods=1, fill_value=None, axis=0):
    # Translate the user-facing fill_value into a representation valid
    # for the backing ndarray before shifting.
    fill = self._validate_shift_value(fill_value)
    shifted = shift(self._ndarray, periods, axis, fill)
    return self._from_backing_data(shifted)

def _validate_shift_value(self, fill_value):
    # Indirection kept for the deprecation cycle in DatetimeLikeArrayMixin.
    # TODO: after deprecation in datetimelikearraymixin is enforced,
    #  we can remove this and use _validate_fill_value directly
    return self._validate_fill_value(fill_value)
126 changes: 6 additions & 120 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@
from pandas.core.accessor import PandasDelegate, delegate_names
import pandas.core.algorithms as algorithms
from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d
from pandas.core.array_algos.transforms import shift
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.base import (
ExtensionArray,
NoNewAttributesMixin,
Expand Down Expand Up @@ -1193,35 +1192,6 @@ def map(self, mapper):
__le__ = _cat_compare_op(operator.le)
__ge__ = _cat_compare_op(operator.ge)

def shift(self, periods=1, fill_value=None):
    """
    Shift Categorical by desired number of periods.

    Parameters
    ----------
    periods : int, default 1
        Number of periods to move, can be positive or negative.
    fill_value : object, optional
        The scalar value to use for newly introduced missing values.
        This should be a category (a value in ``self.categories``).

        .. versionadded:: 0.24.0

    Returns
    -------
    shifted : Categorical

    Raises
    ------
    NotImplementedError
        If the underlying codes have more than one dimension.
    """
    # since categoricals always have ndim == 1, an axis parameter
    # doesn't make any sense here.
    codes = self.codes
    if codes.ndim > 1:
        raise NotImplementedError("Categorical with ndim > 1.")

    # Convert the user-facing fill_value into a representation usable
    # with the underlying codes array.
    fill_value = self._validate_fill_value(fill_value)

    codes = shift(codes, periods, axis=0, fill_value=fill_value)

    return self._constructor(codes, dtype=self.dtype, fastpath=True)

def _validate_fill_value(self, fill_value):
"""
Convert a user-facing fill_value to a representation to use with our
Expand Down Expand Up @@ -1383,20 +1353,6 @@ def notna(self):

notnull = notna

def dropna(self):
    """
    Return the Categorical without null values.

    Missing values (-1 in .codes) are detected.

    Returns
    -------
    valid : Categorical
    """
    # Boolean-mask indexing with the not-null mask drops the missing
    # entries in a single step.
    return self[self.notna()]

def value_counts(self, dropna=True):
"""
Return a Series containing counts of each category.
Expand Down Expand Up @@ -1749,81 +1705,6 @@ def fillna(self, value=None, method=None, limit=None):

return self._constructor(codes, dtype=self.dtype, fastpath=True)

def take(self: _T, indexer, allow_fill: bool = False, fill_value=None) -> _T:
    """
    Take elements from the Categorical.

    Parameters
    ----------
    indexer : sequence of int
        The indices in `self` to take. The meaning of negative values in
        `indexer` depends on the value of `allow_fill`.
    allow_fill : bool, default False
        How to handle negative values in `indexer`.

        * False: negative values in `indices` indicate positional indices
          from the right. This is similar to
          :func:`numpy.take`.

        * True: negative values in `indices` indicate missing values
          (the default). These values are set to `fill_value`. Any
          other negative values raise a ``ValueError``.

        .. versionchanged:: 1.0.0

           Default value changed from ``True`` to ``False``.

    fill_value : object
        The value to use for `indices` that are missing (-1), when
        ``allow_fill=True``. This should be the category, i.e. a value
        in ``self.categories``, not a code.

    Returns
    -------
    Categorical
        This Categorical will have the same categories and ordered as
        `self`.

    See Also
    --------
    Series.take : Similar method for Series.
    numpy.ndarray.take : Similar method for NumPy arrays.

    Examples
    --------
    >>> cat = pd.Categorical(['a', 'a', 'b'])
    >>> cat
    ['a', 'a', 'b']
    Categories (2, object): ['a', 'b']

    Specify ``allow_fill=False`` to have negative indices mean indexing
    from the right.

    >>> cat.take([0, -1, -2], allow_fill=False)
    ['a', 'b', 'a']
    Categories (2, object): ['a', 'b']

    With ``allow_fill=True``, indices equal to ``-1`` mean "missing"
    values that should be filled with the `fill_value`, which is
    ``np.nan`` by default.

    >>> cat.take([0, -1, -1], allow_fill=True)
    ['a', NaN, NaN]
    Categories (2, object): ['a', 'b']

    The fill value can be specified.

    >>> cat.take([0, -1, -1], allow_fill=True, fill_value='a')
    ['a', 'a', 'a']
    Categories (2, object): ['a', 'b']

    Specifying a fill value that's not in ``self.categories``
    will raise a ``ValueError``.
    """
    # Delegate to the shared mixin implementation.
    return NDArrayBackedExtensionArray.take(
        self, indexer, allow_fill=allow_fill, fill_value=fill_value
    )

# ------------------------------------------------------------------
# NDArrayBackedExtensionArray compat

Expand Down Expand Up @@ -1861,6 +1742,9 @@ def __contains__(self, key) -> bool:

return contains(self, key, container=self._codes)

# ------------------------------------------------------------------
# Rendering Methods

def _tidy_repr(self, max_vals=10, footer=True) -> str:
"""
a short repr displaying only max_vals and an optional (but default
Expand Down Expand Up @@ -1959,6 +1843,8 @@ def __repr__(self) -> str:

return result

# ------------------------------------------------------------------

def _maybe_coerce_indexer(self, indexer):
"""
return an indexer coerced to the codes dtype
Expand Down
28 changes: 7 additions & 21 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@

from pandas.core import missing, nanops, ops
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
from pandas.core.array_algos.transforms import shift
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin
from pandas.core.arrays.base import ExtensionOpsMixin
import pandas.core.common as com
from pandas.core.construction import array, extract_array
from pandas.core.indexers import check_array_indexer
Expand Down Expand Up @@ -672,18 +671,11 @@ def view(self, dtype=None):

@classmethod
def _concat_same_type(cls, to_concat, axis: int = 0):

# do not pass tz to set because tzlocal cannot be hashed
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
new_obj = super()._concat_same_type(to_concat, axis)

obj = to_concat[0]
dtype = obj.dtype

i8values = [x.asi8 for x in to_concat]
values = np.concatenate(i8values, axis=axis)

new_freq = None
if is_period_dtype(dtype):
new_freq = obj.freq
Expand All @@ -697,11 +689,13 @@ def _concat_same_type(cls, to_concat, axis: int = 0):
if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs):
new_freq = obj.freq

return cls._simple_new(values, dtype=dtype, freq=new_freq)
new_obj._freq = new_freq
return new_obj

def copy(self: DatetimeLikeArrayT) -> DatetimeLikeArrayT:
    # Copy via the shared base-class implementation, then re-attach the
    # frequency explicitly — NOTE(review): the explicit assignment here
    # suggests the base copy does not carry `_freq` over; confirm against
    # the superclass implementation.
    new_obj = super().copy()
    new_obj._freq = self.freq
    return new_obj

def _values_for_factorize(self):
    # Factorize on the raw integer representation; iNaT is returned as
    # the second element, presumably serving as the NA sentinel per the
    # ExtensionArray._values_for_factorize contract — verify upstream.
    return self.asi8, iNaT
Expand All @@ -713,14 +707,6 @@ def _from_factorized(cls, values, original):
def _values_for_argsort(self):
    # Sort directly on the backing ndarray.
    return self._data

@Appender(ExtensionArray.shift.__doc__)
def shift(self, periods=1, fill_value=None, axis=0):
    # Map the user-facing fill_value to the backing representation,
    # shift the raw data, then rewrap preserving the dtype.
    fill = self._validate_shift_value(fill_value)
    shifted = shift(self._data, periods, axis, fill)
    return type(self)._simple_new(shifted, dtype=self.dtype)

# ------------------------------------------------------------------
# Validation Methods
# TODO: try to de-duplicate these, ensure identical behavior
Expand Down
12 changes: 1 addition & 11 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pandas._libs import lib
from pandas._typing import Scalar
from pandas.compat.numpy import function as nv
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.dtypes import ExtensionDtype
Expand All @@ -16,10 +15,9 @@

from pandas import compat
from pandas.core import nanops, ops
from pandas.core.algorithms import searchsorted
from pandas.core.array_algos import masked_reductions
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin
from pandas.core.arrays.base import ExtensionOpsMixin
from pandas.core.construction import extract_array
from pandas.core.indexers import check_array_indexer
from pandas.core.missing import backfill_1d, pad_1d
Expand Down Expand Up @@ -189,10 +187,6 @@ def _from_sequence(cls, scalars, dtype=None, copy: bool = False) -> "PandasArray
def _from_factorized(cls, values, original) -> "PandasArray":
    # `original` is accepted for interface compatibility but unused:
    # reconstructing a PandasArray only needs the raw values.
    return cls(values)

@classmethod
def _concat_same_type(cls, to_concat) -> "PandasArray":
    # Concatenate the raw arrays and rewrap the result in a PandasArray.
    joined = np.concatenate(to_concat)
    return cls(joined)

def _from_backing_data(self, arr: np.ndarray) -> "PandasArray":
    # Wrap a raw ndarray in a new instance of the same concrete type.
    return type(self)(arr)

Expand Down Expand Up @@ -423,10 +417,6 @@ def to_numpy(

return result

@doc(ExtensionArray.searchsorted)
def searchsorted(self, value, side="left", sorter=None):
    # Operate on a plain ndarray view of the data via the shared
    # algos-level implementation.
    arr = self.to_numpy()
    return searchsorted(arr, value, side=side, sorter=sorter)

# ------------------------------------------------------------------------
# Ops

Expand Down