Skip to content

Commit b522e7e

Browse files
authored
PERF: concat_same_type for PeriodDtype (#52290)
* PERF: concat_same_type for PeriodDtype * mypy fixup
1 parent 261d425 commit b522e7e

File tree

9 files changed

+70
-54
lines changed

9 files changed

+70
-54
lines changed

pandas/_libs/arrays.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ from typing import Sequence
33
import numpy as np
44

55
from pandas._typing import (
6+
AxisInt,
67
DtypeObj,
8+
Self,
79
Shape,
810
)
911

@@ -32,3 +34,7 @@ class NDArrayBacked:
3234
def ravel(self, order=...): ...
3335
@property
3436
def T(self): ...
37+
@classmethod
38+
def _concat_same_type(
39+
cls, to_concat: Sequence[Self], axis: AxisInt = ...
40+
) -> Self: ...

pandas/_libs/arrays.pyx

+7
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,10 @@ cdef class NDArrayBacked:
182182
def transpose(self, *axes):
183183
res_values = self._ndarray.transpose(*axes)
184184
return self._from_backing_data(res_values)
185+
186+
@classmethod
187+
def _concat_same_type(cls, to_concat, axis=0):
188+
# NB: We are assuming at this point that dtypes all match
189+
new_values = [obj._ndarray for obj in to_concat]
190+
new_arr = cnp.PyArray_Concatenate(new_values, axis)
191+
return to_concat[0]._from_backing_data(new_arr)

pandas/core/arrays/_mixins.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,11 @@ def _concat_same_type(
224224
to_concat: Sequence[Self],
225225
axis: AxisInt = 0,
226226
) -> Self:
227-
dtypes = {str(x.dtype) for x in to_concat}
228-
if len(dtypes) != 1:
229-
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
227+
if not lib.dtypes_all_equal([x.dtype for x in to_concat]):
228+
dtypes = {str(x.dtype) for x in to_concat}
229+
raise ValueError("to_concat must have the same dtype", dtypes)
230230

231-
new_values = [x._ndarray for x in to_concat]
232-
new_arr = np.concatenate(new_values, axis=axis)
233-
return to_concat[0]._from_backing_data(new_arr)
231+
return super()._concat_same_type(to_concat, axis=axis)
234232

235233
@doc(ExtensionArray.searchsorted)
236234
def searchsorted(

pandas/core/arrays/datetimelike.py

+34-37
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,11 @@ def new_meth(self, *args, **kwargs):
182182
return cast(F, new_meth)
183183

184184

185-
class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):
185+
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
186+
# incompatible with definition in base class "ExtensionArray"
187+
class DatetimeLikeArrayMixin( # type: ignore[misc]
188+
OpsMixin, NDArrayBackedExtensionArray
189+
):
186190
"""
187191
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
188192
@@ -505,42 +509,6 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
505509
# are present in this file.
506510
return super().view(dtype)
507511

508-
# ------------------------------------------------------------------
509-
# ExtensionArray Interface
510-
511-
@classmethod
512-
def _concat_same_type(
513-
cls,
514-
to_concat: Sequence[Self],
515-
axis: AxisInt = 0,
516-
) -> Self:
517-
new_obj = super()._concat_same_type(to_concat, axis)
518-
519-
obj = to_concat[0]
520-
dtype = obj.dtype
521-
522-
new_freq = None
523-
if isinstance(dtype, PeriodDtype):
524-
new_freq = obj.freq
525-
elif axis == 0:
526-
# GH 3232: If the concat result is evenly spaced, we can retain the
527-
# original frequency
528-
to_concat = [x for x in to_concat if len(x)]
529-
530-
if obj.freq is not None and all(x.freq == obj.freq for x in to_concat):
531-
pairs = zip(to_concat[:-1], to_concat[1:])
532-
if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs):
533-
new_freq = obj.freq
534-
535-
new_obj._freq = new_freq
536-
return new_obj
537-
538-
def copy(self, order: str = "C") -> Self:
539-
# error: Unexpected keyword argument "order" for "copy"
540-
new_obj = super().copy(order=order) # type: ignore[call-arg]
541-
new_obj._freq = self.freq
542-
return new_obj
543-
544512
# ------------------------------------------------------------------
545513
# Validation Methods
546514
# TODO: try to de-duplicate these, ensure identical behavior
@@ -2085,6 +2053,7 @@ def _with_freq(self, freq) -> Self:
20852053
return arr
20862054

20872055
# --------------------------------------------------------------
2056+
# ExtensionArray Interface
20882057

20892058
def factorize(
20902059
self,
@@ -2102,6 +2071,34 @@ def factorize(
21022071
# FIXME: shouldn't get here; we are ignoring sort
21032072
return super().factorize(use_na_sentinel=use_na_sentinel)
21042073

2074+
@classmethod
2075+
def _concat_same_type(
2076+
cls,
2077+
to_concat: Sequence[Self],
2078+
axis: AxisInt = 0,
2079+
) -> Self:
2080+
new_obj = super()._concat_same_type(to_concat, axis)
2081+
2082+
obj = to_concat[0]
2083+
2084+
if axis == 0:
2085+
# GH 3232: If the concat result is evenly spaced, we can retain the
2086+
# original frequency
2087+
to_concat = [x for x in to_concat if len(x)]
2088+
2089+
if obj.freq is not None and all(x.freq == obj.freq for x in to_concat):
2090+
pairs = zip(to_concat[:-1], to_concat[1:])
2091+
if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs):
2092+
new_freq = obj.freq
2093+
new_obj._freq = new_freq
2094+
return new_obj
2095+
2096+
def copy(self, order: str = "C") -> Self:
2097+
# error: Unexpected keyword argument "order" for "copy"
2098+
new_obj = super().copy(order=order) # type: ignore[call-arg]
2099+
new_obj._freq = self.freq
2100+
return new_obj
2101+
21052102

21062103
# -------------------------------------------------------------------
21072104
# Shared Constructor Helpers

pandas/core/arrays/datetimes.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def f(self):
151151
return property(f)
152152

153153

154-
class DatetimeArray(dtl.TimelikeOps, dtl.DatelikeOps):
154+
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
155+
# incompatible with definition in base class "ExtensionArray"
156+
class DatetimeArray(dtl.TimelikeOps, dtl.DatelikeOps): # type: ignore[misc]
155157
"""
156158
Pandas ExtensionArray for tz-naive or tz-aware datetime data.
157159

pandas/core/arrays/numpy_.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
)
4141

4242

43-
class PandasArray(
43+
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
44+
# incompatible with definition in base class "ExtensionArray"
45+
class PandasArray( # type: ignore[misc]
4446
OpsMixin,
4547
NDArrayBackedExtensionArray,
4648
ObjectStringArrayMixin,

pandas/core/arrays/period.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def f(self):
109109
return property(f)
110110

111111

112-
class PeriodArray(dtl.DatelikeOps, libperiod.PeriodMixin):
112+
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
113+
# incompatible with definition in base class "ExtensionArray"
114+
class PeriodArray(dtl.DatelikeOps, libperiod.PeriodMixin): # type: ignore[misc]
113115
"""
114116
Pandas ExtensionArray for storing Period data.
115117
@@ -263,7 +265,10 @@ def _from_sequence(
263265
validate_dtype_freq(scalars.dtype, freq)
264266
if copy:
265267
scalars = scalars.copy()
266-
return scalars
268+
# error: Incompatible return value type
269+
# (got "Union[Sequence[Optional[Period]], Union[Union[ExtensionArray,
270+
# ndarray[Any, Any]], Index, Series]]", expected "PeriodArray")
271+
return scalars # type: ignore[return-value]
267272

268273
periods = np.asarray(scalars, dtype=object)
269274

pandas/core/arrays/string_.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def tolist(self):
228228
return list(self.to_numpy())
229229

230230

231-
class StringArray(BaseStringArray, PandasArray):
231+
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
232+
# incompatible with definition in base class "ExtensionArray"
233+
class StringArray(BaseStringArray, PandasArray): # type: ignore[misc]
232234
"""
233235
Extension array for string data.
234236

pandas/core/dtypes/dtypes.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -956,12 +956,9 @@ def __eq__(self, other: Any) -> bool:
956956
elif isinstance(other, PeriodDtype):
957957
# For freqs that can be held by a PeriodDtype, this check is
958958
# equivalent to (and much faster than) self.freq == other.freq
959-
sfreq = self.freq
960-
ofreq = other.freq
961-
return (
962-
sfreq.n == ofreq.n
963-
and sfreq._period_dtype_code == ofreq._period_dtype_code
964-
)
959+
sfreq = self._freq
960+
ofreq = other._freq
961+
return sfreq.n == ofreq.n and self._dtype_code == other._dtype_code
965962

966963
return False
967964

0 commit comments

Comments
 (0)