diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 4f14ac2a14157..03c8e48c6e699 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -702,12 +702,31 @@ def take(self, indices, allow_fill=False, fill_value=None): @classmethod def _concat_same_type(cls, to_concat): - dtypes = {x.dtype for x in to_concat} - assert len(dtypes) == 1 - dtype = list(dtypes)[0] + + # do not pass tz to set because tzlocal cannot be hashed + dtypes = {str(x.dtype) for x in to_concat} + if len(dtypes) != 1: + raise ValueError("to_concat must have the same dtype (tz)", dtypes) + + obj = to_concat[0] + dtype = obj.dtype values = np.concatenate([x.asi8 for x in to_concat]) - return cls(values, dtype=dtype) + + if is_period_dtype(to_concat[0].dtype): + new_freq = obj.freq + else: + # GH 3232: If the concat result is evenly spaced, we can retain the + # original frequency + new_freq = None + to_concat = [x for x in to_concat if len(x)] + + if obj.freq is not None and all(x.freq == obj.freq for x in to_concat): + pairs = zip(to_concat[:-1], to_concat[1:]) + if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs): + new_freq = obj.freq + + return cls._simple_new(values, dtype=dtype, freq=new_freq) def copy(self): values = self.asi8.copy() diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index b143ff0aa9c02..2d46211c9544f 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -6,7 +6,6 @@ import numpy as np from pandas._libs import NaT, iNaT, join as libjoin, lib -from pandas._libs.algos import unique_deltas from pandas._libs.tslibs import timezones from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -515,20 +514,9 @@ def _concat_same_dtype(self, to_concat, name): Concatenate to_concat which has the same class. """ - # do not pass tz to set because tzlocal cannot be hashed - if len({str(x.dtype) for x in to_concat}) != 1: - raise ValueError("to_concat must have the same tz") - new_data = type(self._data)._concat_same_type(to_concat) - if not is_period_dtype(self.dtype): - # GH 3232: If the concat result is evenly spaced, we can retain the - # original frequency - is_diff_evenly_spaced = len(unique_deltas(new_data.asi8)) == 1 - if is_diff_evenly_spaced: - new_data._freq = self.freq - - return type(self)._simple_new(new_data, name=name) + return self._simple_new(new_data, name=name) def shift(self, periods=1, freq=None): """ diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 42f0a012902a3..bbec3e6f911c3 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -697,12 +697,28 @@ def _assert_can_do_setop(self, other): if isinstance(other, PeriodIndex) and self.freq != other.freq: raise raise_on_incompatible(self, other) - def intersection(self, other, sort=False): + def _setop(self, other, sort, opname: str): + """ + Perform a set operation by dispatching to the Int64Index implementation. + """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) res_name = get_op_result_name(self, other) other = ensure_index(other) + i8self = Int64Index._simple_new(self.asi8) + i8other = Int64Index._simple_new(other.asi8) + i8result = getattr(i8self, opname)(i8other, sort=sort) + + parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype) + result = type(self)._simple_new(parr, name=res_name) + return result + + def intersection(self, other, sort=False): + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other = ensure_index(other) + if self.equals(other): return self._get_reconciled_name_object(other) @@ -712,22 +728,16 @@ def intersection(self, other, sort=False): other = other.astype("O") return this.intersection(other, sort=sort) - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self.intersection(i8other, sort=sort) - - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="intersection") def difference(self, other, sort=None): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - res_name = get_op_result_name(self, other) other = ensure_index(other) if self.equals(other): # pass an empty PeriodArray with the appropriate dtype - return self._shallow_copy(self._data[:0]) + return type(self)._simple_new(self._data[:0], name=self.name) if is_object_dtype(other): return self.astype(object).difference(other).astype(self.dtype) @@ -735,12 +745,7 @@ def difference(self, other, sort=None): elif not is_dtype_equal(self.dtype, other.dtype): return self - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self.difference(i8other, sort=sort) - - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="difference") def _union(self, other, sort): if not len(other) or self.equals(other) or not len(self): @@ -754,13 +759,7 @@ def _union(self, other, sort): other = other.astype("O") return this._union(other, sort=sort) - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self._union(i8other, sort=sort) - - res_name = get_op_result_name(self, other) - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="_union") # ------------------------------------------------------------------------ diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 87b825c8c27bd..17818b6ce689f 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -463,7 +463,7 @@ def test_concat_same_type_invalid(self, datetime_index): else: other = arr.tz_localize(None) - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="to_concat must have the same"): arr._concat_same_type([arr, other]) def test_concat_same_type_different_freq(self):