@@ -1121,7 +1121,7 @@ def _convert_can_do_setop(self, other):
1121
1121
other = Index (other , name = self .name )
1122
1122
result_name = self .name
1123
1123
else :
1124
- result_name = self .name if self . name == other . name else None
1124
+ result_name = self ._get_intersection_name ( other )
1125
1125
return other , result_name
1126
1126
1127
1127
def _convert_for_op (self , value ):
@@ -2190,19 +2190,35 @@ def __or__(self, other):
2190
2190
def __xor__ (self , other ):
2191
2191
return self .symmetric_difference (other )
2192
2192
2193
- def _get_consensus_name (self , other ):
2193
+ def _get_union_name (self , other ):
2194
+ # GH 9943 9862
2194
2195
"""
2195
- Given 2 indexes, give a consensus name meaning
2196
+ Given 2 indexes, give the union name meaning
2196
2197
we take the not None one, or None if the names differ.
2197
- Return a new object if we are resetting the name
2198
2198
"""
2199
2199
if self .name != other .name :
2200
2200
if self .name is None or other .name is None :
2201
2201
name = self .name or other .name
2202
2202
else :
2203
2203
name = None
2204
- if self .name != name :
2205
- return self ._shallow_copy (name = name )
2204
+ else :
2205
+ name = self .name
2206
+ return name
2207
+
2208
+ def _get_intersection_name (self , other ):
2209
+ # GH 9943 9862
2210
+ return self .name if self .name == other .name else None
2211
+
2212
+ def _get_consensus_name_object (self , other , name_converter ):
2213
+ """
2214
+ Given 2 indexes, give a consensus name meaning
2215
+ we use the name converter (either _get_union_name or
2216
+ get_intersection_name) to determine the name.
2217
+ Return a new object if we are resetting the name
2218
+ """
2219
+ name = name_converter (other )
2220
+ if self .name != name :
2221
+ return self ._shallow_copy (name = name )
2206
2222
return self
2207
2223
2208
2224
def union (self , other ):
@@ -2230,10 +2246,12 @@ def union(self, other):
2230
2246
other = _ensure_index (other )
2231
2247
2232
2248
if len (other ) == 0 or self .equals (other ):
2233
- return self ._get_consensus_name (other )
2249
+ return self ._get_consensus_name_object (other ,
2250
+ self ._get_union_name )
2234
2251
2235
2252
if len (self ) == 0 :
2236
- return other ._get_consensus_name (self )
2253
+ return other ._get_consensus_name_object (self ,
2254
+ other ._get_union_name )
2237
2255
2238
2256
# TODO: is_dtype_union_equal is a hack around
2239
2257
# 1. buggy set ops with duplicates (GH #13432)
@@ -2296,11 +2314,15 @@ def union(self, other):
2296
2314
stacklevel = 3 )
2297
2315
2298
2316
# for subclasses
2299
- return self ._wrap_union_result (other , result )
2317
+ return self ._wrap_setop_result (other , result , self . _get_union_name )
2300
2318
2301
- def _wrap_union_result (self , other , result ):
2302
- name = self .name if self .name == other .name else None
2303
- return self .__class__ (result , name = name )
2319
+ def _wrap_setop_result (self , other , result , name_func ):
2320
+ # GH 9943 9862
2321
+ """
2322
+ name_func is either self._get_union_name or
2323
+ self._get_intersection_name
2324
+ """
2325
+ return self .__class__ (result , name = name_func (other ))
2304
2326
2305
2327
def intersection (self , other ):
2306
2328
"""
@@ -2330,7 +2352,8 @@ def intersection(self, other):
2330
2352
other = _ensure_index (other )
2331
2353
2332
2354
if self .equals (other ):
2333
- return self ._get_consensus_name (other )
2355
+ return self ._get_consensus_name_object (other ,
2356
+ self ._get_intersection_name )
2334
2357
2335
2358
if not is_dtype_equal (self .dtype , other .dtype ):
2336
2359
this = self .astype ('O' )
@@ -2350,7 +2373,8 @@ def intersection(self, other):
2350
2373
if self .is_monotonic and other .is_monotonic :
2351
2374
try :
2352
2375
result = self ._inner_indexer (lvals , rvals )[0 ]
2353
- return self ._wrap_union_result (other , result )
2376
+ return self ._wrap_setop_result (other , result ,
2377
+ self ._get_intersection_name )
2354
2378
except TypeError :
2355
2379
pass
2356
2380
@@ -3479,7 +3503,7 @@ def _join_monotonic(self, other, how='left', return_indexers=False):
3479
3503
return join_index
3480
3504
3481
3505
def _wrap_joined_index (self , joined , other ):
3482
- name = self .name if self . name == other . name else None
3506
+ name = self ._get_intersection_name ( other )
3483
3507
return Index (joined , name = name )
3484
3508
3485
3509
def _get_string_slice (self , key , use_lhs = True , use_rhs = True ):
0 commit comments