Skip to content

Commit 2312f7e

Browse files
committed
BUG: Index.union cannot handle array-likes
1 parent 1a709c3 commit 2312f7e

File tree

6 files changed

+177
-35
lines changed

6 files changed

+177
-35
lines changed

doc/source/whatsnew/v0.17.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Bug Fixes
6464

6565

6666
- Bug where Panel.from_dict does not set dtype when specified (:issue:`10058`)
67+
- Bug in ``Index.union`` raises ``AttributeError`` when passing array-likes. (:issue:`10149`)
6768
- Bug in ``Timestamp``'s' ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties return ``np.int`` type, not built-in ``int``. (:issue:`10050`)
6869
- Bug in ``NaT`` raises ``AttributeError`` when accessing to ``daysinmonth``, ``dayofweek`` properties. (:issue:`10096`)
6970

@@ -76,3 +77,4 @@ Bug Fixes
7677
- Bug in `Series.plot(label="LABEL")` not correctly setting the label (:issue:`10119`)
7778

7879

80+

pandas/core/index.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ def to_datetime(self, dayfirst=False):
580580
return DatetimeIndex(self.values)
581581

582582
def _assert_can_do_setop(self, other):
583+
if not com.is_list_like(other):
584+
raise TypeError('Input must be Index or array-like')
583585
return True
584586

585587
@property
@@ -1364,16 +1366,14 @@ def union(self, other):
13641366
-------
13651367
union : Index
13661368
"""
1367-
if not hasattr(other, '__iter__'):
1368-
raise TypeError('Input must be iterable.')
1369+
self._assert_can_do_setop(other)
1370+
other = _ensure_index(other)
13691371

13701372
if len(other) == 0 or self.equals(other):
13711373
return self
13721374

13731375
if len(self) == 0:
1374-
return _ensure_index(other)
1375-
1376-
self._assert_can_do_setop(other)
1376+
return other
13771377

13781378
if not is_dtype_equal(self.dtype,other.dtype):
13791379
this = self.astype('O')
@@ -1439,11 +1439,7 @@ def intersection(self, other):
14391439
-------
14401440
intersection : Index
14411441
"""
1442-
if not hasattr(other, '__iter__'):
1443-
raise TypeError('Input must be iterable!')
1444-
14451442
self._assert_can_do_setop(other)
1446-
14471443
other = _ensure_index(other)
14481444

14491445
if self.equals(other):
@@ -1492,9 +1488,7 @@ def difference(self, other):
14921488
14931489
>>> index.difference(index2)
14941490
"""
1495-
1496-
if not hasattr(other, '__iter__'):
1497-
raise TypeError('Input must be iterable!')
1491+
self._assert_can_do_setop(other)
14981492

14991493
if self.equals(other):
15001494
return Index([], name=self.name)
@@ -1517,7 +1511,7 @@ def sym_diff(self, other, result_name=None):
15171511
Parameters
15181512
----------
15191513
1520-
other : array-like
1514+
other : Index or array-like
15211515
result_name : str
15221516
15231517
Returns
@@ -1545,9 +1539,7 @@ def sym_diff(self, other, result_name=None):
15451539
>>> idx1 ^ idx2
15461540
Int64Index([1, 5], dtype='int64')
15471541
"""
1548-
if not hasattr(other, '__iter__'):
1549-
raise TypeError('Input must be iterable!')
1550-
1542+
self._assert_can_do_setop(other)
15511543
if not isinstance(other, Index):
15521544
other = Index(other)
15531545
result_name = result_name or self.name
@@ -5537,9 +5529,6 @@ def difference(self, other):
55375529
return MultiIndex.from_tuples(difference, sortorder=0,
55385530
names=result_names)
55395531

5540-
def _assert_can_do_setop(self, other):
5541-
pass
5542-
55435532
def astype(self, dtype):
55445533
if not is_object_dtype(np.dtype(dtype)):
55455534
raise TypeError('Setting %s dtype to anything other than object '

pandas/tests/test_index.py

+161-16
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,129 @@ def test_take(self):
251251
expected = ind[indexer]
252252
self.assertTrue(result.equals(expected))
253253

254+
def test_setops_errorcases(self):
255+
for name, idx in compat.iteritems(self.indices):
256+
# # non-iterable input
257+
cases = [0.5, 'xxx']
258+
methods = [idx.intersection, idx.union, idx.difference, idx.sym_diff]
259+
260+
for method in methods:
261+
for case in cases:
262+
assertRaisesRegexp(TypeError,
263+
"Input must be Index or array-like",
264+
method, case)
265+
266+
def test_intersection_base(self):
267+
for name, idx in compat.iteritems(self.indices):
268+
first = idx[:5]
269+
second = idx[:3]
270+
intersect = first.intersection(second)
271+
272+
if isinstance(idx, CategoricalIndex):
273+
pass
274+
else:
275+
self.assertTrue(tm.equalContents(intersect, second))
276+
277+
# GH 10149
278+
cases = [klass(second.values) for klass in [np.array, Series, list]]
279+
for case in cases:
280+
if isinstance(idx, PeriodIndex):
281+
msg = "can only call with other PeriodIndex-ed objects"
282+
with tm.assertRaisesRegexp(ValueError, msg):
283+
result = first.intersection(case)
284+
elif isinstance(idx, CategoricalIndex):
285+
pass
286+
elif isinstance(idx, MultiIndex):
287+
pass
288+
else:
289+
result = first.intersection(case)
290+
self.assertTrue(tm.equalContents(result, second))
291+
292+
def test_union_base(self):
293+
for name, idx in compat.iteritems(self.indices):
294+
first = idx[3:]
295+
second = idx[:5]
296+
everything = idx
297+
union = first.union(second)
298+
self.assertTrue(tm.equalContents(union, everything))
299+
300+
# GH 10149
301+
cases = [klass(second.values) for klass in [np.array, Series, list]]
302+
for case in cases:
303+
if isinstance(idx, PeriodIndex):
304+
msg = "can only call with other PeriodIndex-ed objects"
305+
with tm.assertRaisesRegexp(ValueError, msg):
306+
result = first.union(case)
307+
elif isinstance(idx, MultiIndex):
308+
pass
309+
elif isinstance(idx, CategoricalIndex):
310+
pass
311+
elif isinstance(idx, TimedeltaIndex):
312+
# checked by tdi._is_convertible_to_index
313+
pass
314+
else:
315+
result = first.union(case)
316+
self.assertTrue(tm.equalContents(result, everything))
317+
318+
def test_difference_base(self):
319+
for name, idx in compat.iteritems(self.indices):
320+
first = idx[2:]
321+
second = idx[:4]
322+
answer = idx[4:]
323+
result = first.difference(second)
324+
325+
if isinstance(idx, CategoricalIndex):
326+
pass
327+
else:
328+
self.assertTrue(tm.equalContents(result, answer))
329+
330+
# GH 10149
331+
cases = [klass(second.values) for klass in [np.array, Series, list]]
332+
for case in cases:
333+
if isinstance(idx, PeriodIndex):
334+
msg = "can only call with other PeriodIndex-ed objects"
335+
with tm.assertRaisesRegexp(ValueError, msg):
336+
result = first.difference(case)
337+
elif isinstance(idx, MultiIndex):
338+
pass
339+
elif isinstance(idx, CategoricalIndex):
340+
pass
341+
elif isinstance(idx, TimedeltaIndex):
342+
pass
343+
elif isinstance(idx, DatetimeIndex):
344+
# freq is not preserved even if possible
345+
self.assert_numpy_array_equal(result.asi8, answer.asi8)
346+
else:
347+
result = first.difference(case)
348+
self.assertTrue(tm.equalContents(result, answer))
349+
350+
def test_symmetric_diff(self):
351+
for name, idx in compat.iteritems(self.indices):
352+
first = idx[1:]
353+
second = idx[:-1]
354+
if isinstance(idx, CategoricalIndex):
355+
pass
356+
else:
357+
answer = idx[[0, -1]]
358+
result = first.sym_diff(second)
359+
self.assertTrue(tm.equalContents(result, answer))
360+
361+
# GH 10149
362+
cases = [klass(second.values) for klass in [np.array, Series, list]]
363+
for case in cases:
364+
if isinstance(idx, PeriodIndex):
365+
msg = "can only call with other PeriodIndex-ed objects"
366+
with tm.assertRaisesRegexp(ValueError, msg):
367+
result = first.sym_diff(case)
368+
elif isinstance(idx, MultiIndex):
369+
pass
370+
elif isinstance(idx, CategoricalIndex):
371+
pass
372+
else:
373+
result = first.sym_diff(case)
374+
self.assertTrue(tm.equalContents(result, answer))
375+
376+
254377
class TestIndex(Base, tm.TestCase):
255378
_holder = Index
256379
_multiprocess_can_split_ = True
@@ -620,16 +743,12 @@ def test_intersection(self):
620743
first = self.strIndex[:20]
621744
second = self.strIndex[:10]
622745
intersect = first.intersection(second)
623-
624746
self.assertTrue(tm.equalContents(intersect, second))
625747

626748
# Corner cases
627749
inter = first.intersection(first)
628750
self.assertIs(inter, first)
629751

630-
# non-iterable input
631-
assertRaisesRegexp(TypeError, "iterable", first.intersection, 0.5)
632-
633752
idx1 = Index([1, 2, 3, 4, 5], name='idx')
634753
# if target has the same name, it is preserved
635754
idx2 = Index([3, 4, 5, 6, 7], name='idx')
@@ -671,6 +790,12 @@ def test_union(self):
671790
union = first.union(second)
672791
self.assertTrue(tm.equalContents(union, everything))
673792

793+
# GH 10149
794+
cases = [klass(second.values) for klass in [np.array, Series, list]]
795+
for case in cases:
796+
result = first.union(case)
797+
self.assertTrue(tm.equalContents(result, everything))
798+
674799
# Corner cases
675800
union = first.union(first)
676801
self.assertIs(union, first)
@@ -681,9 +806,6 @@ def test_union(self):
681806
union = Index([]).union(first)
682807
self.assertIs(union, first)
683808

684-
# non-iterable input
685-
assertRaisesRegexp(TypeError, "iterable", first.union, 0.5)
686-
687809
# preserve names
688810
first.name = 'A'
689811
second.name = 'A'
@@ -792,11 +914,7 @@ def test_difference(self):
792914
self.assertEqual(len(result), 0)
793915
self.assertEqual(result.name, first.name)
794916

795-
# non-iterable input
796-
assertRaisesRegexp(TypeError, "iterable", first.difference, 0.5)
797-
798917
def test_symmetric_diff(self):
799-
800918
# smoke
801919
idx1 = Index([1, 2, 3, 4], name='idx1')
802920
idx2 = Index([2, 3, 4, 5])
@@ -842,10 +960,6 @@ def test_symmetric_diff(self):
842960
self.assertTrue(tm.equalContents(result, expected))
843961
self.assertEqual(result.name, 'new_name')
844962

845-
# other isn't iterable
846-
with tm.assertRaises(TypeError):
847-
Index(idx1,dtype='object').difference(1)
848-
849963
def test_is_numeric(self):
850964
self.assertFalse(self.dateIndex.is_numeric())
851965
self.assertFalse(self.strIndex.is_numeric())
@@ -1786,6 +1900,7 @@ def test_equals(self):
17861900
self.assertFalse(CategoricalIndex(list('aabca') + [np.nan],categories=['c','a','b',np.nan]).equals(list('aabca')))
17871901
self.assertTrue(CategoricalIndex(list('aabca') + [np.nan],categories=['c','a','b',np.nan]).equals(list('aabca') + [np.nan]))
17881902

1903+
17891904
class Numeric(Base):
17901905

17911906
def test_numeric_compat(self):
@@ -2642,6 +2757,36 @@ def test_time_overflow_for_32bit_machines(self):
26422757
idx2 = pd.date_range(end='2000', periods=periods, freq='S')
26432758
self.assertEqual(len(idx2), periods)
26442759

2760+
def test_intersection(self):
2761+
first = self.index
2762+
second = self.index[5:]
2763+
intersect = first.intersection(second)
2764+
self.assertTrue(tm.equalContents(intersect, second))
2765+
2766+
# GH 10149
2767+
cases = [klass(second.values) for klass in [np.array, Series, list]]
2768+
for case in cases:
2769+
result = first.intersection(case)
2770+
self.assertTrue(tm.equalContents(result, second))
2771+
2772+
third = Index(['a', 'b', 'c'])
2773+
result = first.intersection(third)
2774+
expected = pd.Index([], dtype=object)
2775+
self.assert_index_equal(result, expected)
2776+
2777+
def test_union(self):
2778+
first = self.index[:5]
2779+
second = self.index[5:]
2780+
everything = self.index
2781+
union = first.union(second)
2782+
self.assertTrue(tm.equalContents(union, everything))
2783+
2784+
# GH 10149
2785+
cases = [klass(second.values) for klass in [np.array, Series, list]]
2786+
for case in cases:
2787+
result = first.union(case)
2788+
self.assertTrue(tm.equalContents(result, everything))
2789+
26452790

26462791
class TestPeriodIndex(DatetimeLike, tm.TestCase):
26472792
_holder = PeriodIndex
@@ -2652,7 +2797,7 @@ def setUp(self):
26522797
self.setup_indices()
26532798

26542799
def create_index(self):
2655-
return period_range('20130101',periods=5,freq='D')
2800+
return period_range('20130101', periods=5, freq='D')
26562801

26572802
def test_pickle_compat_construction(self):
26582803
pass

pandas/tseries/index.py

+2
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ def union(self, other):
804804
-------
805805
y : Index or DatetimeIndex
806806
"""
807+
self._assert_can_do_setop(other)
807808
if not isinstance(other, DatetimeIndex):
808809
try:
809810
other = DatetimeIndex(other)
@@ -1039,6 +1040,7 @@ def intersection(self, other):
10391040
-------
10401041
y : Index or DatetimeIndex
10411042
"""
1043+
self._assert_can_do_setop(other)
10421044
if not isinstance(other, DatetimeIndex):
10431045
try:
10441046
other = DatetimeIndex(other)

pandas/tseries/period.py

+2
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,8 @@ def join(self, other, how='left', level=None, return_indexers=False):
680680
return self._apply_meta(result)
681681

682682
def _assert_can_do_setop(self, other):
683+
super(PeriodIndex, self)._assert_can_do_setop(other)
684+
683685
if not isinstance(other, PeriodIndex):
684686
raise ValueError('can only call with other PeriodIndex-ed objects')
685687

pandas/tseries/tdi.py

+2
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def union(self, other):
436436
-------
437437
y : Index or TimedeltaIndex
438438
"""
439+
self._assert_can_do_setop(other)
439440
if _is_convertible_to_index(other):
440441
try:
441442
other = TimedeltaIndex(other)
@@ -581,6 +582,7 @@ def intersection(self, other):
581582
-------
582583
y : Index or TimedeltaIndex
583584
"""
585+
self._assert_can_do_setop(other)
584586
if not isinstance(other, TimedeltaIndex):
585587
try:
586588
other = TimedeltaIndex(other)

0 commit comments

Comments
 (0)