From 518460606de165fb2ea6df48cd9fa825a5e82dde Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sun, 26 Jan 2020 20:39:18 -0800 Subject: [PATCH 1/2] REF: share code for set-like ops in DTI/TDI/PI --- pandas/core/arrays/datetimelike.py | 27 ++++++++++++--- pandas/core/indexes/datetimelike.py | 11 ------ pandas/core/indexes/datetimes.py | 2 +- pandas/core/indexes/period.py | 43 ++++++++++++------------ pandas/tests/arrays/test_datetimelike.py | 2 +- 5 files changed, 46 insertions(+), 39 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 70637026c278d..318f6c2f73f3f 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -691,12 +691,31 @@ def take(self, indices, allow_fill=False, fill_value=None): @classmethod def _concat_same_type(cls, to_concat): - dtypes = {x.dtype for x in to_concat} - assert len(dtypes) == 1 - dtype = list(dtypes)[0] + + # do not pass tz to set because tzlocal cannot be hashed + dtypes = {str(x.dtype) for x in to_concat} + if len(dtypes) != 1: + raise ValueError("to_concat must have the same dtype (tz)", dtypes) + + obj = to_concat[0] + dtype = obj.dtype values = np.concatenate([x.asi8 for x in to_concat]) - return cls(values, dtype=dtype) + + if is_period_dtype(to_concat[0].dtype): + new_freq = obj.freq + else: + # GH 3232: If the concat result is evenly spaced, we can retain the + # original frequency + new_freq = None + to_concat = [x for x in to_concat if len(x)] + + if obj.freq is not None and all(x.freq == obj.freq for x in to_concat): + pairs = zip(to_concat[:-1], to_concat[1:]) + if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs): + new_freq = obj.freq + + return cls._simple_new(values, dtype=dtype, freq=new_freq) def copy(self): values = self.asi8.copy() diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index b87dd0f02252f..5a1f427780d17 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -503,19 +503,8 @@ def _concat_same_dtype(self, to_concat, name): Concatenate to_concat which has the same class. """ - # do not pass tz to set because tzlocal cannot be hashed - if len({str(x.dtype) for x in to_concat}) != 1: - raise ValueError("to_concat must have the same tz") - new_data = type(self._data)._concat_same_type(to_concat) - if not is_period_dtype(self.dtype): - # GH 3232: If the concat result is evenly spaced, we can retain the - # original frequency - is_diff_evenly_spaced = len(unique_deltas(new_data.asi8)) == 1 - if is_diff_evenly_spaced: - new_data._freq = self.freq - return self._simple_new(new_data, name=name) def shift(self, periods=1, freq=None): diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index b269239ed10ac..3afd1ff35806d 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -377,7 +377,7 @@ def union_many(self, others): def _wrap_setop_result(self, other, result): name = get_op_result_name(self, other) - return self._shallow_copy(result, name=name, freq=None, tz=self.tz) + return self._shallow_copy(result, name=name, freq=None) # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index fe6c1ba808f9a..6d85cd74c6765 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -758,12 +758,28 @@ def _assert_can_do_setop(self, other): if isinstance(other, PeriodIndex) and self.freq != other.freq: raise raise_on_incompatible(self, other) - def intersection(self, other, sort=False): + def _setop(self, other, sort, opname: str): + """ + Perform a set operation by dispatching to the Int64Index implementation. + """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) res_name = get_op_result_name(self, other) other = ensure_index(other) + i8self = Int64Index._simple_new(self.asi8) + i8other = Int64Index._simple_new(other.asi8) + i8result = getattr(i8self, opname)(i8other, sort=sort) + + parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype) + result = type(self)._simple_new(parr, name=res_name) + return result + + def intersection(self, other, sort=False): + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other = ensure_index(other) + if self.equals(other): return self._get_reconciled_name_object(other) @@ -773,22 +789,16 @@ def intersection(self, other, sort=False): other = other.astype("O") return this.intersection(other, sort=sort) - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self.intersection(i8other, sort=sort) - - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="intersection") def difference(self, other, sort=None): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - res_name = get_op_result_name(self, other) other = ensure_index(other) if self.equals(other): # pass an empty PeriodArray with the appropriate dtype - return self._shallow_copy(self._data[:0]) + return type(self)._simple_new(self._data[:0], name=self.name) if is_object_dtype(other): return self.astype(object).difference(other).astype(self.dtype) @@ -796,12 +806,7 @@ def difference(self, other, sort=None): elif not is_dtype_equal(self.dtype, other.dtype): return self - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self.difference(i8other, sort=sort) - - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="difference") def _union(self, other, sort): if not len(other) or self.equals(other) or not len(self): @@ -815,13 +820,7 @@ def _union(self, other, sort): other = other.astype("O") return this._union(other, sort=sort) - i8self = Int64Index._simple_new(self.asi8) - i8other = Int64Index._simple_new(other.asi8) - i8result = i8self._union(i8other, sort=sort) - - res_name = get_op_result_name(self, other) - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) - return result + return self._setop(other, sort, opname="_union") # ------------------------------------------------------------------------ diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 87b825c8c27bd..17818b6ce689f 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -463,7 +463,7 @@ def test_concat_same_type_invalid(self, datetime_index): else: other = arr.tz_localize(None) - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="to_concat must have the same"): arr._concat_same_type([arr, other]) def test_concat_same_type_different_freq(self): From c1c28ca136f5a33ba60453d0809610f0ff4efa28 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Mon, 27 Jan 2020 09:03:53 -0800 Subject: [PATCH 2/2] fixup unused import --- pandas/core/indexes/datetimelike.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 5a1f427780d17..ea35d75f5f677 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -7,7 +7,6 @@ import numpy as np from pandas._libs import NaT, iNaT, join as libjoin, lib -from pandas._libs.algos import unique_deltas from pandas._libs.tslibs import timezones from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError