Skip to content

Commit 7d1f825

Browse files
authored
REF: share code for set-like ops in DTI/TDI/PI (#31335)
1 parent f7e2b74 commit 7d1f825

File tree

4 files changed: +46 -40 lines changed

pandas/core/arrays/datetimelike.py

+23 -4
Original file line number | Diff line number | Diff line change
@@ -702,12 +702,31 @@ def take(self, indices, allow_fill=False, fill_value=None):
702702

703703
@classmethod
704704
def _concat_same_type(cls, to_concat):
705-
dtypes = {x.dtype for x in to_concat}
706-
assert len(dtypes) == 1
707-
dtype = list(dtypes)[0]
705+
706+
# do not pass tz to set because tzlocal cannot be hashed
707+
dtypes = {str(x.dtype) for x in to_concat}
708+
if len(dtypes) != 1:
709+
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
710+
711+
obj = to_concat[0]
712+
dtype = obj.dtype
708713

709714
values = np.concatenate([x.asi8 for x in to_concat])
710-
return cls(values, dtype=dtype)
715+
716+
if is_period_dtype(to_concat[0].dtype):
717+
new_freq = obj.freq
718+
else:
719+
# GH 3232: If the concat result is evenly spaced, we can retain the
720+
# original frequency
721+
new_freq = None
722+
to_concat = [x for x in to_concat if len(x)]
723+
724+
if obj.freq is not None and all(x.freq == obj.freq for x in to_concat):
725+
pairs = zip(to_concat[:-1], to_concat[1:])
726+
if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs):
727+
new_freq = obj.freq
728+
729+
return cls._simple_new(values, dtype=dtype, freq=new_freq)
711730

712731
def copy(self):
713732
values = self.asi8.copy()

pandas/core/indexes/datetimelike.py

+1 -13
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77

88
from pandas._libs import NaT, iNaT, join as libjoin, lib
9-
from pandas._libs.algos import unique_deltas
109
from pandas._libs.tslibs import timezones
1110
from pandas.compat.numpy import function as nv
1211
from pandas.errors import AbstractMethodError
@@ -515,20 +514,9 @@ def _concat_same_dtype(self, to_concat, name):
515514
Concatenate to_concat which has the same class.
516515
"""
517516

518-
# do not pass tz to set because tzlocal cannot be hashed
519-
if len({str(x.dtype) for x in to_concat}) != 1:
520-
raise ValueError("to_concat must have the same tz")
521-
522517
new_data = type(self._data)._concat_same_type(to_concat)
523518

524-
if not is_period_dtype(self.dtype):
525-
# GH 3232: If the concat result is evenly spaced, we can retain the
526-
# original frequency
527-
is_diff_evenly_spaced = len(unique_deltas(new_data.asi8)) == 1
528-
if is_diff_evenly_spaced:
529-
new_data._freq = self.freq
530-
531-
return type(self)._simple_new(new_data, name=name)
519+
return self._simple_new(new_data, name=name)
532520

533521
def shift(self, periods=1, freq=None):
534522
"""

pandas/core/indexes/period.py

+21 -22
Original file line number | Diff line number | Diff line change
@@ -697,12 +697,28 @@ def _assert_can_do_setop(self, other):
697697
if isinstance(other, PeriodIndex) and self.freq != other.freq:
698698
raise raise_on_incompatible(self, other)
699699

700-
def intersection(self, other, sort=False):
700+
def _setop(self, other, sort, opname: str):
701+
"""
702+
Perform a set operation by dispatching to the Int64Index implementation.
703+
"""
701704
self._validate_sort_keyword(sort)
702705
self._assert_can_do_setop(other)
703706
res_name = get_op_result_name(self, other)
704707
other = ensure_index(other)
705708

709+
i8self = Int64Index._simple_new(self.asi8)
710+
i8other = Int64Index._simple_new(other.asi8)
711+
i8result = getattr(i8self, opname)(i8other, sort=sort)
712+
713+
parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype)
714+
result = type(self)._simple_new(parr, name=res_name)
715+
return result
716+
717+
def intersection(self, other, sort=False):
718+
self._validate_sort_keyword(sort)
719+
self._assert_can_do_setop(other)
720+
other = ensure_index(other)
721+
706722
if self.equals(other):
707723
return self._get_reconciled_name_object(other)
708724

@@ -712,35 +728,24 @@ def intersection(self, other, sort=False):
712728
other = other.astype("O")
713729
return this.intersection(other, sort=sort)
714730

715-
i8self = Int64Index._simple_new(self.asi8)
716-
i8other = Int64Index._simple_new(other.asi8)
717-
i8result = i8self.intersection(i8other, sort=sort)
718-
719-
result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
720-
return result
731+
return self._setop(other, sort, opname="intersection")
721732

722733
def difference(self, other, sort=None):
723734
self._validate_sort_keyword(sort)
724735
self._assert_can_do_setop(other)
725-
res_name = get_op_result_name(self, other)
726736
other = ensure_index(other)
727737

728738
if self.equals(other):
729739
# pass an empty PeriodArray with the appropriate dtype
730-
return self._shallow_copy(self._data[:0])
740+
return type(self)._simple_new(self._data[:0], name=self.name)
731741

732742
if is_object_dtype(other):
733743
return self.astype(object).difference(other).astype(self.dtype)
734744

735745
elif not is_dtype_equal(self.dtype, other.dtype):
736746
return self
737747

738-
i8self = Int64Index._simple_new(self.asi8)
739-
i8other = Int64Index._simple_new(other.asi8)
740-
i8result = i8self.difference(i8other, sort=sort)
741-
742-
result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
743-
return result
748+
return self._setop(other, sort, opname="difference")
744749

745750
def _union(self, other, sort):
746751
if not len(other) or self.equals(other) or not len(self):
@@ -754,13 +759,7 @@ def _union(self, other, sort):
754759
other = other.astype("O")
755760
return this._union(other, sort=sort)
756761

757-
i8self = Int64Index._simple_new(self.asi8)
758-
i8other = Int64Index._simple_new(other.asi8)
759-
i8result = i8self._union(i8other, sort=sort)
760-
761-
res_name = get_op_result_name(self, other)
762-
result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
763-
return result
762+
return self._setop(other, sort, opname="_union")
764763

765764
# ------------------------------------------------------------------------
766765

pandas/tests/arrays/test_datetimelike.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -463,7 +463,7 @@ def test_concat_same_type_invalid(self, datetime_index):
463463
else:
464464
other = arr.tz_localize(None)
465465

466-
with pytest.raises(AssertionError):
466+
with pytest.raises(ValueError, match="to_concat must have the same"):
467467
arr._concat_same_type([arr, other])
468468

469469
def test_concat_same_type_different_freq(self):

0 commit comments

Comments
 (0)