Skip to content

Commit eb1ae6b

Browse files
committed
Merge pull request #7458 from sinhrks/intersection
BUG: DTI.intersection doesnt preserve tz
2 parents 8cfff98 + e592f1d commit eb1ae6b

File tree

7 files changed

+146
-45
lines changed

7 files changed

+146
-45
lines changed

doc/source/v0.14.1.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,5 @@ Bug Fixes
233233

234234

235235

236-
236+
- Bug in non-monotonic ``Index.union`` may preserve ``name`` incorrectly (:issue:`7458`)
237+
- Bug in ``DatetimeIndex.intersection`` doesn't preserve timezone (:issue:`4690`)

pandas/core/index.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,8 @@ def take(self, indexer, axis=0):
777777
"""
778778
indexer = com._ensure_platform_int(indexer)
779779
taken = self.view(np.ndarray).take(indexer)
780-
return self._constructor(taken, name=self.name)
780+
return self._simple_new(taken, name=self.name, freq=None,
781+
tz=getattr(self, 'tz', None))
781782

782783
def format(self, name=False, formatter=None, **kwargs):
783784
"""
@@ -1075,7 +1076,10 @@ def intersection(self, other):
10751076
# duplicates
10761077
indexer = self.get_indexer_non_unique(other.values)[0].unique()
10771078

1078-
return self.take(indexer)
1079+
taken = self.take(indexer)
1080+
if self.name != other.name:
1081+
taken.name = None
1082+
return taken
10791083

10801084
def diff(self, other):
10811085
"""

pandas/tests/test_index.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from datetime import datetime, timedelta
44
from pandas.compat import range, lrange, lzip, u, zip
5-
import sys
65
import operator
76
import pickle
87
import re
@@ -447,6 +446,33 @@ def test_intersection(self):
447446
# non-iterable input
448447
assertRaisesRegexp(TypeError, "iterable", first.intersection, 0.5)
449448

449+
idx1 = Index([1, 2, 3, 4, 5], name='idx')
450+
# if target has the same name, it is preserved
451+
idx2 = Index([3, 4, 5, 6, 7], name='idx')
452+
expected2 = Index([3, 4, 5], name='idx')
453+
result2 = idx1.intersection(idx2)
454+
self.assertTrue(result2.equals(expected2))
455+
self.assertEqual(result2.name, expected2.name)
456+
457+
# if target name is different, it will be reset
458+
idx3 = Index([3, 4, 5, 6, 7], name='other')
459+
expected3 = Index([3, 4, 5], name=None)
460+
result3 = idx1.intersection(idx3)
461+
self.assertTrue(result3.equals(expected3))
462+
self.assertEqual(result3.name, expected3.name)
463+
464+
# non monotonic
465+
idx1 = Index([5, 3, 2, 4, 1], name='idx')
466+
idx2 = Index([4, 7, 6, 5, 3], name='idx')
467+
result2 = idx1.intersection(idx2)
468+
self.assertTrue(tm.equalContents(result2, expected2))
469+
self.assertEqual(result2.name, expected2.name)
470+
471+
idx3 = Index([4, 7, 6, 5, 3], name='other')
472+
result3 = idx1.intersection(idx3)
473+
self.assertTrue(tm.equalContents(result3, expected3))
474+
self.assertEqual(result3.name, expected3.name)
475+
450476
def test_union(self):
451477
first = self.strIndex[5:20]
452478
second = self.strIndex[:10]

pandas/tseries/index.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -900,9 +900,7 @@ def take(self, indices, axis=0):
900900
maybe_slice = lib.maybe_indices_to_slice(com._ensure_int64(indices))
901901
if isinstance(maybe_slice, slice):
902902
return self[maybe_slice]
903-
indices = com._ensure_platform_int(indices)
904-
taken = self.values.take(indices, axis=axis)
905-
return self._simple_new(taken, self.name, None, self.tz)
903+
return super(DatetimeIndex, self).take(indices, axis)
906904

907905
def unique(self):
908906
"""
@@ -1125,6 +1123,12 @@ def __array_finalize__(self, obj):
11251123
self.name = getattr(obj, 'name', None)
11261124
self._reset_identity()
11271125

1126+
def _wrap_union_result(self, other, result):
1127+
name = self.name if self.name == other.name else None
1128+
if self.tz != other.tz:
1129+
raise ValueError('Passed item and index have different timezone')
1130+
return self._simple_new(result, name=name, freq=None, tz=self.tz)
1131+
11281132
def intersection(self, other):
11291133
"""
11301134
Specialized intersection for DatetimeIndex objects. May be much faster

pandas/tseries/period.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1133,10 +1133,7 @@ def take(self, indices, axis=None):
11331133
"""
11341134
indices = com._ensure_platform_int(indices)
11351135
taken = self.values.take(indices, axis=axis)
1136-
taken = taken.view(PeriodIndex)
1137-
taken.freq = self.freq
1138-
taken.name = self.name
1139-
return taken
1136+
return self._simple_new(taken, self.name, freq=self.freq)
11401137

11411138
def append(self, other):
11421139
"""

pandas/tseries/tests/test_period.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -2070,14 +2070,19 @@ def test_iteration(self):
20702070
self.assertEqual(result[0].freq, index.freq)
20712071

20722072
def test_take(self):
2073-
index = PeriodIndex(start='1/1/10', end='12/31/12', freq='D')
2073+
index = PeriodIndex(start='1/1/10', end='12/31/12', freq='D', name='idx')
2074+
expected = PeriodIndex([datetime(2010, 1, 6), datetime(2010, 1, 7),
2075+
datetime(2010, 1, 9), datetime(2010, 1, 13)],
2076+
freq='D', name='idx')
20742077

2075-
taken = index.take([5, 6, 8, 12])
2078+
taken1 = index.take([5, 6, 8, 12])
20762079
taken2 = index[[5, 6, 8, 12]]
2077-
tm.assert_isinstance(taken, PeriodIndex)
2078-
self.assertEqual(taken.freq, index.freq)
2079-
tm.assert_isinstance(taken2, PeriodIndex)
2080-
self.assertEqual(taken2.freq, index.freq)
2080+
2081+
for taken in [taken1, taken2]:
2082+
self.assertTrue(taken.equals(expected))
2083+
tm.assert_isinstance(taken, PeriodIndex)
2084+
self.assertEqual(taken.freq, index.freq)
2085+
self.assertEqual(taken.name, expected.name)
20812086

20822087
def test_joins(self):
20832088
index = period_range('1/1/2000', '1/20/2000', freq='D')

pandas/tseries/tests/test_timeseries.py

+92-28
Original file line numberDiff line numberDiff line change
@@ -2467,6 +2467,25 @@ def test_delete_slice(self):
24672467
self.assertEqual(result.name, expected.name)
24682468
self.assertEqual(result.freq, expected.freq)
24692469

2470+
def test_take(self):
2471+
dates = [datetime(2010, 1, 6), datetime(2010, 1, 7),
2472+
datetime(2010, 1, 9), datetime(2010, 1, 13)]
2473+
2474+
for tz in [None, 'US/Eastern', 'Asia/Tokyo']:
2475+
idx = DatetimeIndex(start='1/1/10', end='12/31/12',
2476+
freq='D', tz=tz, name='idx')
2477+
expected = DatetimeIndex(dates, freq=None, name='idx', tz=tz)
2478+
2479+
taken1 = idx.take([5, 6, 8, 12])
2480+
taken2 = idx[[5, 6, 8, 12]]
2481+
2482+
for taken in [taken1, taken2]:
2483+
self.assertTrue(taken.equals(expected))
2484+
tm.assert_isinstance(taken, DatetimeIndex)
2485+
self.assertIsNone(taken.freq)
2486+
self.assertEqual(taken.tz, expected.tz)
2487+
self.assertEqual(taken.name, expected.name)
2488+
24702489
def test_map_bug_1677(self):
24712490
index = DatetimeIndex(['2012-04-25 09:30:00.393000'])
24722491
f = index.asof
@@ -3035,14 +3054,46 @@ def test_union(self):
30353054
self.assertEqual(df.index.values.dtype, np.dtype('M8[ns]'))
30363055

30373056
def test_intersection(self):
3038-
rng = date_range('6/1/2000', '6/15/2000', freq='D')
3039-
rng = rng.delete(5)
3040-
3041-
rng2 = date_range('5/15/2000', '6/20/2000', freq='D')
3042-
rng2 = DatetimeIndex(rng2.values)
3043-
3044-
result = rng.intersection(rng2)
3045-
self.assertTrue(result.equals(rng))
3057+
# GH 4690 (with tz)
3058+
for tz in [None, 'Asia/Tokyo']:
3059+
rng = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')
3060+
3061+
# if target has the same name, it is preserved
3062+
rng2 = date_range('5/15/2000', '6/20/2000', freq='D', name='idx')
3063+
expected2 = date_range('6/1/2000', '6/20/2000', freq='D', name='idx')
3064+
3065+
# if target name is different, it will be reset
3066+
rng3 = date_range('5/15/2000', '6/20/2000', freq='D', name='other')
3067+
expected3 = date_range('6/1/2000', '6/20/2000', freq='D', name=None)
3068+
3069+
result2 = rng.intersection(rng2)
3070+
result3 = rng.intersection(rng3)
3071+
for (result, expected) in [(result2, expected2), (result3, expected3)]:
3072+
self.assertTrue(result.equals(expected))
3073+
self.assertEqual(result.name, expected.name)
3074+
self.assertEqual(result.freq, expected.freq)
3075+
self.assertEqual(result.tz, expected.tz)
3076+
3077+
# non-monotonic
3078+
rng = DatetimeIndex(['2011-01-05', '2011-01-04', '2011-01-02', '2011-01-03'],
3079+
tz=tz, name='idx')
3080+
3081+
rng2 = DatetimeIndex(['2011-01-04', '2011-01-02', '2011-02-02', '2011-02-03'],
3082+
tz=tz, name='idx')
3083+
expected2 = DatetimeIndex(['2011-01-04', '2011-01-02'], tz=tz, name='idx')
3084+
3085+
rng3 = DatetimeIndex(['2011-01-04', '2011-01-02', '2011-02-02', '2011-02-03'],
3086+
tz=tz, name='other')
3087+
expected3 = DatetimeIndex(['2011-01-04', '2011-01-02'], tz=tz, name=None)
3088+
3089+
result2 = rng.intersection(rng2)
3090+
result3 = rng.intersection(rng3)
3091+
for (result, expected) in [(result2, expected2), (result3, expected3)]:
3092+
print(result, expected)
3093+
self.assertTrue(result.equals(expected))
3094+
self.assertEqual(result.name, expected.name)
3095+
self.assertIsNone(result.freq)
3096+
self.assertEqual(result.tz, expected.tz)
30463097

30473098
# empty same freq GH2129
30483099
rng = date_range('6/1/2000', '6/15/2000', freq='T')
@@ -3571,26 +3622,39 @@ def test_shift(self):
35713622
self.assertRaises(ValueError, idx.shift, 1)
35723623

35733624
def test_setops_preserve_freq(self):
3574-
rng = date_range('1/1/2000', '1/1/2002')
3575-
3576-
result = rng[:50].union(rng[50:100])
3577-
self.assertEqual(result.freq, rng.freq)
3578-
3579-
result = rng[:50].union(rng[30:100])
3580-
self.assertEqual(result.freq, rng.freq)
3581-
3582-
result = rng[:50].union(rng[60:100])
3583-
self.assertIsNone(result.freq)
3584-
3585-
result = rng[:50].intersection(rng[25:75])
3586-
self.assertEqual(result.freqstr, 'D')
3587-
3588-
nofreq = DatetimeIndex(list(rng[25:75]))
3589-
result = rng[:50].union(nofreq)
3590-
self.assertEqual(result.freq, rng.freq)
3591-
3592-
result = rng[:50].intersection(nofreq)
3593-
self.assertEqual(result.freq, rng.freq)
3625+
for tz in [None, 'Asia/Tokyo', 'US/Eastern']:
3626+
rng = date_range('1/1/2000', '1/1/2002', name='idx', tz=tz)
3627+
3628+
result = rng[:50].union(rng[50:100])
3629+
self.assertEqual(result.name, rng.name)
3630+
self.assertEqual(result.freq, rng.freq)
3631+
self.assertEqual(result.tz, rng.tz)
3632+
3633+
result = rng[:50].union(rng[30:100])
3634+
self.assertEqual(result.name, rng.name)
3635+
self.assertEqual(result.freq, rng.freq)
3636+
self.assertEqual(result.tz, rng.tz)
3637+
3638+
result = rng[:50].union(rng[60:100])
3639+
self.assertEqual(result.name, rng.name)
3640+
self.assertIsNone(result.freq)
3641+
self.assertEqual(result.tz, rng.tz)
3642+
3643+
result = rng[:50].intersection(rng[25:75])
3644+
self.assertEqual(result.name, rng.name)
3645+
self.assertEqual(result.freqstr, 'D')
3646+
self.assertEqual(result.tz, rng.tz)
3647+
3648+
nofreq = DatetimeIndex(list(rng[25:75]), name='other')
3649+
result = rng[:50].union(nofreq)
3650+
self.assertIsNone(result.name)
3651+
self.assertEqual(result.freq, rng.freq)
3652+
self.assertEqual(result.tz, rng.tz)
3653+
3654+
result = rng[:50].intersection(nofreq)
3655+
self.assertIsNone(result.name)
3656+
self.assertEqual(result.freq, rng.freq)
3657+
self.assertEqual(result.tz, rng.tz)
35943658

35953659
def test_min_max(self):
35963660
rng = date_range('1/1/2000', '12/31/2000')

0 commit comments

Comments
 (0)