Skip to content

REF: share code for set-like ops in DTI/TDI/PI #31335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 9, 2020
27 changes: 23 additions & 4 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,31 @@ def take(self, indices, allow_fill=False, fill_value=None):

@classmethod
def _concat_same_type(cls, to_concat):
dtypes = {x.dtype for x in to_concat}
assert len(dtypes) == 1
dtype = list(dtypes)[0]

# do not pass tz to set because tzlocal cannot be hashed
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
raise ValueError("to_concat must have the same dtype (tz)", dtypes)

obj = to_concat[0]
dtype = obj.dtype

values = np.concatenate([x.asi8 for x in to_concat])
return cls(values, dtype=dtype)

if is_period_dtype(to_concat[0].dtype):
new_freq = obj.freq
else:
# GH 3232: If the concat result is evenly spaced, we can retain the
# original frequency
new_freq = None
to_concat = [x for x in to_concat if len(x)]

if obj.freq is not None and all(x.freq == obj.freq for x in to_concat):
pairs = zip(to_concat[:-1], to_concat[1:])
if all(pair[0][-1] + obj.freq == pair[1][0] for pair in pairs):
new_freq = obj.freq

return cls._simple_new(values, dtype=dtype, freq=new_freq)

def copy(self):
values = self.asi8.copy()
Expand Down
12 changes: 0 additions & 12 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np

from pandas._libs import NaT, iNaT, join as libjoin, lib
from pandas._libs.algos import unique_deltas
from pandas._libs.tslibs import timezones
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
Expand Down Expand Up @@ -503,19 +502,8 @@ def _concat_same_dtype(self, to_concat, name):
Concatenate to_concat which has the same class.
"""

# do not pass tz to set because tzlocal cannot be hashed
if len({str(x.dtype) for x in to_concat}) != 1:
raise ValueError("to_concat must have the same tz")

new_data = type(self._data)._concat_same_type(to_concat)

if not is_period_dtype(self.dtype):
# GH 3232: If the concat result is evenly spaced, we can retain the
# original frequency
is_diff_evenly_spaced = len(unique_deltas(new_data.asi8)) == 1
if is_diff_evenly_spaced:
new_data._freq = self.freq

return self._simple_new(new_data, name=name)

def shift(self, periods=1, freq=None):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def union_many(self, others):

def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
return self._shallow_copy(result, name=name, freq=None)

# --------------------------------------------------------------------

Expand Down
43 changes: 21 additions & 22 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,28 @@ def _assert_can_do_setop(self, other):
if isinstance(other, PeriodIndex) and self.freq != other.freq:
raise raise_on_incompatible(self, other)

def intersection(self, other, sort=False):
def _setop(self, other, sort, opname: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eventually if you can type / doc-string this

"""
Perform a set operation by dispatching to the Int64Index implementation.
"""
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
res_name = get_op_result_name(self, other)
other = ensure_index(other)

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = getattr(i8self, opname)(i8other, sort=sort)

parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype)
result = type(self)._simple_new(parr, name=res_name)
return result

def intersection(self, other, sort=False):
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other = ensure_index(other)

if self.equals(other):
return self._get_reconciled_name_object(other)

Expand All @@ -773,35 +789,24 @@ def intersection(self, other, sort=False):
other = other.astype("O")
return this.intersection(other, sort=sort)

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self.intersection(i8other, sort=sort)

result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
return result
return self._setop(other, sort, opname="intersection")

def difference(self, other, sort=None):
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
res_name = get_op_result_name(self, other)
other = ensure_index(other)

if self.equals(other):
# pass an empty PeriodArray with the appropriate dtype
return self._shallow_copy(self._data[:0])
return type(self)._simple_new(self._data[:0], name=self.name)

if is_object_dtype(other):
return self.astype(object).difference(other).astype(self.dtype)

elif not is_dtype_equal(self.dtype, other.dtype):
return self

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self.difference(i8other, sort=sort)

result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
return result
return self._setop(other, sort, opname="difference")

def _union(self, other, sort):
if not len(other) or self.equals(other) or not len(self):
Expand All @@ -815,13 +820,7 @@ def _union(self, other, sort):
other = other.astype("O")
return this._union(other, sort=sort)

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self._union(i8other, sort=sort)

res_name = get_op_result_name(self, other)
result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
return result
return self._setop(other, sort, opname="_union")

# ------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_concat_same_type_invalid(self, datetime_index):
else:
other = arr.tz_localize(None)

with pytest.raises(AssertionError):
with pytest.raises(ValueError, match="to_concat must have the same"):
arr._concat_same_type([arr, other])

def test_concat_same_type_different_freq(self):
Expand Down