Skip to content

PERF: concat_same_type for PeriodDtype #52290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pandas/_libs/arrays.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ from typing import Sequence
import numpy as np

from pandas._typing import (
AxisInt,
DtypeObj,
Self,
Shape,
)

Expand Down Expand Up @@ -32,3 +34,7 @@ class NDArrayBacked:
def ravel(self, order=...): ...
@property
def T(self): ...
# Typing stub for NDArrayBacked._concat_same_type.
# Takes a sequence of instances of the same class plus an axis and returns
# an instance of that same class (Self).  The default for ``axis`` is elided
# in this stub (``...``); the implementation in arrays.pyx declares axis=0.
@classmethod
def _concat_same_type(
cls, to_concat: Sequence[Self], axis: AxisInt = ...
) -> Self: ...
7 changes: 7 additions & 0 deletions pandas/_libs/arrays.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,10 @@ cdef class NDArrayBacked:
def transpose(self, *axes):
res_values = self._ndarray.transpose(*axes)
return self._from_backing_data(res_values)

@classmethod
def _concat_same_type(cls, to_concat, axis=0):
    """
    Concatenate the backing ndarrays of ``to_concat`` along ``axis``.

    The concatenated ndarray is wrapped by calling ``_from_backing_data``
    on the first element of ``to_concat``.  No dtype validation is done
    here.
    """
    # NB: callers are responsible for ensuring every dtype matches;
    #  nothing is checked at this level.
    backing_arrays = [arr._ndarray for arr in to_concat]
    concatenated = cnp.PyArray_Concatenate(backing_arrays, axis)

    # Index only after concatenating so that failure ordering is unchanged.
    template = to_concat[0]
    return template._from_backing_data(concatenated)
10 changes: 4 additions & 6 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,11 @@ def _concat_same_type(
to_concat: Sequence[Self],
axis: AxisInt = 0,
) -> Self:
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
if not lib.dtypes_all_equal([x.dtype for x in to_concat]):
dtypes = {str(x.dtype) for x in to_concat}
raise ValueError("to_concat must have the same dtype", dtypes)

new_values = [x._ndarray for x in to_concat]
new_arr = np.concatenate(new_values, axis=axis)
return to_concat[0]._from_backing_data(new_arr)
return super()._concat_same_type(to_concat, axis=axis)

@doc(ExtensionArray.searchsorted)
def searchsorted(
Expand Down
71 changes: 34 additions & 37 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def new_meth(self, *args, **kwargs):
return cast(F, new_meth)


class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
class DatetimeLikeArrayMixin( # type: ignore[misc]
OpsMixin, NDArrayBackedExtensionArray
):
"""
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray

Expand Down Expand Up @@ -505,42 +509,6 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
# are present in this file.
return super().view(dtype)

# ------------------------------------------------------------------
# ExtensionArray Interface

@classmethod
def _concat_same_type(
    cls,
    to_concat: Sequence[Self],
    axis: AxisInt = 0,
) -> Self:
    """
    Concatenate via the parent implementation, then set ``_freq``.

    For a PeriodDtype the first array's ``freq`` is always carried over.
    Otherwise ``_freq`` is the first array's ``freq`` only when
    concatenating along axis 0, every non-empty piece has that same
    ``freq``, and each piece begins exactly one ``freq`` after the last
    element of the piece before it; in every other case it is None.
    """
    result = super()._concat_same_type(to_concat, axis)

    first = to_concat[0]
    first_dtype = first.dtype

    if isinstance(first_dtype, PeriodDtype):
        result._freq = first.freq
        return result

    retained_freq = None
    if axis == 0:
        # GH 3232: If the concat result is evenly spaced, we can retain the
        # original frequency
        non_empty = [arr for arr in to_concat if len(arr)]

        if first.freq is not None and all(
            arr.freq == first.freq for arr in non_empty
        ):
            neighbours = zip(non_empty[:-1], non_empty[1:])
            evenly_spaced = all(
                left[-1] + first.freq == right[0] for left, right in neighbours
            )
            if evenly_spaced:
                retained_freq = first.freq

    result._freq = retained_freq
    return result

def copy(self, order: str = "C") -> Self:
    """
    Copy via the parent implementation, forwarding ``order``, and set the
    copy's ``_freq`` to this array's ``freq``.
    """
    # error: Unexpected keyword argument "order" for "copy"
    duplicate = super().copy(order=order)  # type: ignore[call-arg]
    duplicate._freq = self.freq
    return duplicate

# ------------------------------------------------------------------
# Validation Methods
# TODO: try to de-duplicate these, ensure identical behavior
Expand Down Expand Up @@ -2085,6 +2053,7 @@ def _with_freq(self, freq) -> Self:
return arr

# --------------------------------------------------------------
# ExtensionArray Interface

def factorize(
self,
Expand All @@ -2102,6 +2071,34 @@ def factorize(
# FIXME: shouldn't get here; we are ignoring sort
return super().factorize(use_na_sentinel=use_na_sentinel)

@classmethod
def _concat_same_type(
    cls,
    to_concat: Sequence[Self],
    axis: AxisInt = 0,
) -> Self:
    """
    Concatenate via the parent implementation, retaining ``freq`` if possible.

    The result's ``_freq`` is assigned only when concatenating along
    axis 0, every non-empty piece shares the first array's non-None
    ``freq``, and each piece begins exactly one ``freq`` after the last
    element of the piece before it.  Otherwise ``_freq`` is left as the
    parent implementation produced it.
    """
    result = super()._concat_same_type(to_concat, axis)

    first = to_concat[0]

    if axis != 0:
        return result

    # GH 3232: If the concat result is evenly spaced, we can retain the
    # original frequency
    non_empty = [arr for arr in to_concat if len(arr)]

    if first.freq is None or not all(
        arr.freq == first.freq for arr in non_empty
    ):
        return result

    neighbours = zip(non_empty[:-1], non_empty[1:])
    if all(left[-1] + first.freq == right[0] for left, right in neighbours):
        # The assignment lives in this innermost branch: it is the only
        # place a frequency to retain has been established.
        result._freq = first.freq
    return result

def copy(self, order: str = "C") -> Self:
    """
    Return the parent implementation's copy (made with the given ``order``)
    after assigning this array's ``freq`` to its ``_freq``.
    """
    # error: Unexpected keyword argument "order" for "copy"
    copied = super().copy(order=order)  # type: ignore[call-arg]
    copied._freq = self.freq
    return copied


# -------------------------------------------------------------------
# Shared Constructor Helpers
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def f(self):
return property(f)


class DatetimeArray(dtl.TimelikeOps, dtl.DatelikeOps):
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
class DatetimeArray(dtl.TimelikeOps, dtl.DatelikeOps): # type: ignore[misc]
"""
Pandas ExtensionArray for tz-naive or tz-aware datetime data.

Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
)


class PandasArray(
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
class PandasArray( # type: ignore[misc]
OpsMixin,
NDArrayBackedExtensionArray,
ObjectStringArrayMixin,
Expand Down
9 changes: 7 additions & 2 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def f(self):
return property(f)


class PeriodArray(dtl.DatelikeOps, libperiod.PeriodMixin):
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
class PeriodArray(dtl.DatelikeOps, libperiod.PeriodMixin): # type: ignore[misc]
"""
Pandas ExtensionArray for storing Period data.

Expand Down Expand Up @@ -263,7 +265,10 @@ def _from_sequence(
validate_dtype_freq(scalars.dtype, freq)
if copy:
scalars = scalars.copy()
return scalars
# error: Incompatible return value type
# (got "Union[Sequence[Optional[Period]], Union[Union[ExtensionArray,
# ndarray[Any, Any]], Index, Series]]", expected "PeriodArray")
return scalars # type: ignore[return-value]

periods = np.asarray(scalars, dtype=object)

Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def tolist(self):
return list(self.to_numpy())


class StringArray(BaseStringArray, PandasArray):
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
class StringArray(BaseStringArray, PandasArray): # type: ignore[misc]
"""
Extension array for string data.

Expand Down
9 changes: 3 additions & 6 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,12 +956,9 @@ def __eq__(self, other: Any) -> bool:
elif isinstance(other, PeriodDtype):
# For freqs that can be held by a PeriodDtype, this check is
# equivalent to (and much faster than) self.freq == other.freq
sfreq = self.freq
ofreq = other.freq
return (
sfreq.n == ofreq.n
and sfreq._period_dtype_code == ofreq._period_dtype_code
)
sfreq = self._freq
ofreq = other._freq
return sfreq.n == ofreq.n and self._dtype_code == other._dtype_code

return False

Expand Down