@@ -580,8 +580,18 @@ def to_datetime(self, dayfirst=False):
580
580
return DatetimeIndex (self .values )
581
581
582
582
def _assert_can_do_setop (self , other ):
583
+ if not com .is_list_like (other ):
584
+ raise TypeError ('Input must be Index or array-like' )
583
585
return True
584
586
587
+ def _convert_can_do_setop (self , other ):
588
+ if not isinstance (other , Index ):
589
+ other = Index (other , name = self .name )
590
+ result_name = self .name
591
+ else :
592
+ result_name = self .name if self .name == other .name else None
593
+ return other , result_name
594
+
585
595
@property
586
596
def nlevels (self ):
587
597
return 1
@@ -1364,16 +1374,14 @@ def union(self, other):
1364
1374
-------
1365
1375
union : Index
1366
1376
"""
1367
- if not hasattr (other , '__iter__' ):
1368
- raise TypeError ( 'Input must be iterable.' )
1377
+ self . _assert_can_do_setop (other )
1378
+ other = _ensure_index ( other )
1369
1379
1370
1380
if len (other ) == 0 or self .equals (other ):
1371
1381
return self
1372
1382
1373
1383
if len (self ) == 0 :
1374
- return _ensure_index (other )
1375
-
1376
- self ._assert_can_do_setop (other )
1384
+ return other
1377
1385
1378
1386
if not is_dtype_equal (self .dtype ,other .dtype ):
1379
1387
this = self .astype ('O' )
@@ -1439,11 +1447,7 @@ def intersection(self, other):
1439
1447
-------
1440
1448
intersection : Index
1441
1449
"""
1442
- if not hasattr (other , '__iter__' ):
1443
- raise TypeError ('Input must be iterable!' )
1444
-
1445
1450
self ._assert_can_do_setop (other )
1446
-
1447
1451
other = _ensure_index (other )
1448
1452
1449
1453
if self .equals (other ):
@@ -1492,18 +1496,12 @@ def difference(self, other):
1492
1496
1493
1497
>>> index.difference(index2)
1494
1498
"""
1495
-
1496
- if not hasattr (other , '__iter__' ):
1497
- raise TypeError ('Input must be iterable!' )
1499
+ self ._assert_can_do_setop (other )
1498
1500
1499
1501
if self .equals (other ):
1500
1502
return Index ([], name = self .name )
1501
1503
1502
- if not isinstance (other , Index ):
1503
- other = np .asarray (other )
1504
- result_name = self .name
1505
- else :
1506
- result_name = self .name if self .name == other .name else None
1504
+ other , result_name = self ._convert_can_do_setop (other )
1507
1505
1508
1506
theDiff = sorted (set (self ) - set (other ))
1509
1507
return Index (theDiff , name = result_name )
@@ -1517,7 +1515,7 @@ def sym_diff(self, other, result_name=None):
1517
1515
Parameters
1518
1516
----------
1519
1517
1520
- other : array-like
1518
+ other : Index or array-like
1521
1519
result_name : str
1522
1520
1523
1521
Returns
@@ -1545,13 +1543,10 @@ def sym_diff(self, other, result_name=None):
1545
1543
>>> idx1 ^ idx2
1546
1544
Int64Index([1, 5], dtype='int64')
1547
1545
"""
1548
- if not hasattr (other , '__iter__' ):
1549
- raise TypeError ('Input must be iterable!' )
1550
-
1551
- if not isinstance (other , Index ):
1552
- other = Index (other )
1553
- result_name = result_name or self .name
1554
-
1546
+ self ._assert_can_do_setop (other )
1547
+ other , result_name_update = self ._convert_can_do_setop (other )
1548
+ if result_name is None :
1549
+ result_name = result_name_update
1555
1550
the_diff = sorted (set ((self .difference (other )).union (other .difference (self ))))
1556
1551
return Index (the_diff , name = result_name )
1557
1552
@@ -5460,12 +5455,11 @@ def union(self, other):
5460
5455
>>> index.union(index2)
5461
5456
"""
5462
5457
self ._assert_can_do_setop (other )
5458
+ other , result_names = self ._convert_can_do_setop (other )
5463
5459
5464
5460
if len (other ) == 0 or self .equals (other ):
5465
5461
return self
5466
5462
5467
- result_names = self .names if self .names == other .names else None
5468
-
5469
5463
uniq_tuples = lib .fast_unique_multiple ([self .values , other .values ])
5470
5464
return MultiIndex .from_arrays (lzip (* uniq_tuples ), sortorder = 0 ,
5471
5465
names = result_names )
@@ -5483,12 +5477,11 @@ def intersection(self, other):
5483
5477
Index
5484
5478
"""
5485
5479
self ._assert_can_do_setop (other )
5480
+ other , result_names = self ._convert_can_do_setop (other )
5486
5481
5487
5482
if self .equals (other ):
5488
5483
return self
5489
5484
5490
- result_names = self .names if self .names == other .names else None
5491
-
5492
5485
self_tuples = self .values
5493
5486
other_tuples = other .values
5494
5487
uniq_tuples = sorted (set (self_tuples ) & set (other_tuples ))
@@ -5509,18 +5502,10 @@ def difference(self, other):
5509
5502
diff : MultiIndex
5510
5503
"""
5511
5504
self ._assert_can_do_setop (other )
5505
+ other , result_names = self ._convert_can_do_setop (other )
5512
5506
5513
- if not isinstance (other , MultiIndex ):
5514
- if len (other ) == 0 :
5507
+ if len (other ) == 0 :
5515
5508
return self
5516
- try :
5517
- other = MultiIndex .from_tuples (other )
5518
- except :
5519
- raise TypeError ('other must be a MultiIndex or a list of'
5520
- ' tuples' )
5521
- result_names = self .names
5522
- else :
5523
- result_names = self .names if self .names == other .names else None
5524
5509
5525
5510
if self .equals (other ):
5526
5511
return MultiIndex (levels = [[]] * self .nlevels ,
@@ -5537,15 +5522,30 @@ def difference(self, other):
5537
5522
return MultiIndex .from_tuples (difference , sortorder = 0 ,
5538
5523
names = result_names )
5539
5524
5540
- def _assert_can_do_setop (self , other ):
5541
- pass
5542
-
5543
5525
def astype (self , dtype ):
5544
5526
if not is_object_dtype (np .dtype (dtype )):
5545
5527
raise TypeError ('Setting %s dtype to anything other than object '
5546
5528
'is not supported' % self .__class__ )
5547
5529
return self ._shallow_copy ()
5548
5530
5531
+ def _convert_can_do_setop (self , other ):
5532
+ result_names = self .names
5533
+
5534
+ if not hasattr (other , 'names' ):
5535
+ if len (other ) == 0 :
5536
+ other = MultiIndex (levels = [[]] * self .nlevels ,
5537
+ labels = [[]] * self .nlevels ,
5538
+ verify_integrity = False )
5539
+ else :
5540
+ msg = 'other must be a MultiIndex or a list of tuples'
5541
+ try :
5542
+ other = MultiIndex .from_tuples (other )
5543
+ except :
5544
+ raise TypeError (msg )
5545
+ else :
5546
+ result_names = self .names if self .names == other .names else None
5547
+ return other , result_names
5548
+
5549
5549
def insert (self , loc , item ):
5550
5550
"""
5551
5551
Make new MultiIndex inserting new item at location
0 commit comments