1
1
from copy import copy as copy_func
2
2
from datetime import datetime
3
+ from itertools import zip_longest
3
4
import operator
4
5
from textwrap import dedent
5
6
from typing import (
11
12
List ,
12
13
Optional ,
13
14
Sequence ,
15
+ Tuple ,
16
+ Type ,
14
17
TypeVar ,
15
18
Union ,
16
19
)
@@ -2492,7 +2495,7 @@ def _get_reconciled_name_object(self, other):
2492
2495
"""
2493
2496
name = get_op_result_name (self , other )
2494
2497
if self .name != name :
2495
- return self ._shallow_copy ( name = name )
2498
+ return self .rename ( name )
2496
2499
return self
2497
2500
2498
2501
def _union_incompatible_dtypes (self , other , sort ):
@@ -2600,7 +2603,9 @@ def union(self, other, sort=None):
2600
2603
if not self ._can_union_without_object_cast (other ):
2601
2604
return self ._union_incompatible_dtypes (other , sort = sort )
2602
2605
2603
- return self ._union (other , sort = sort )
2606
+ result = self ._union (other , sort = sort )
2607
+
2608
+ return self ._wrap_setop_result (other , result )
2604
2609
2605
2610
def _union (self , other , sort ):
2606
2611
"""
@@ -2622,10 +2627,10 @@ def _union(self, other, sort):
2622
2627
Index
2623
2628
"""
2624
2629
if not len (other ) or self .equals (other ):
2625
- return self . _get_reconciled_name_object ( other )
2630
+ return self
2626
2631
2627
2632
if not len (self ):
2628
- return other . _get_reconciled_name_object ( self )
2633
+ return other
2629
2634
2630
2635
# TODO(EA): setops-refactor, clean all this up
2631
2636
lvals = self ._values
@@ -2667,12 +2672,16 @@ def _union(self, other, sort):
2667
2672
stacklevel = 3 ,
2668
2673
)
2669
2674
2670
- # for subclasses
2671
- return self ._wrap_setop_result (other , result )
2675
+ return self ._shallow_copy (result )
2672
2676
2673
2677
def _wrap_setop_result (self , other , result ):
2674
2678
name = get_op_result_name (self , other )
2675
- return self ._shallow_copy (result , name = name )
2679
+ if isinstance (result , Index ):
2680
+ if result .name != name :
2681
+ return result .rename (name )
2682
+ return result
2683
+ else :
2684
+ return self ._shallow_copy (result , name = name )
2676
2685
2677
2686
# TODO: standardize return type of non-union setops type(self vs other)
2678
2687
def intersection (self , other , sort = False ):
@@ -2742,15 +2751,12 @@ def intersection(self, other, sort=False):
2742
2751
indexer = algos .unique1d (Index (rvals ).get_indexer_non_unique (lvals )[0 ])
2743
2752
indexer = indexer [indexer != - 1 ]
2744
2753
2745
- taken = other .take (indexer )
2746
- res_name = get_op_result_name (self , other )
2754
+ result = other .take (indexer )
2747
2755
2748
2756
if sort is None :
2749
- taken = algos .safe_sort (taken .values )
2750
- return self ._shallow_copy (taken , name = res_name )
2757
+ result = algos .safe_sort (result .values )
2751
2758
2752
- taken .name = res_name
2753
- return taken
2759
+ return self ._wrap_setop_result (other , result )
2754
2760
2755
2761
def difference (self , other , sort = None ):
2756
2762
"""
@@ -5935,3 +5941,22 @@ def _maybe_asobject(dtype, klass, data, copy: bool, name: Label, **kwargs):
5935
5941
return index .astype (object )
5936
5942
5937
5943
return klass (data , dtype = dtype , copy = copy , name = name , ** kwargs )
5944
+
5945
+
5946
+ def get_unanimous_names (* indexes : Type [Index ]) -> Tuple [Any , ...]:
5947
+ """
5948
+ Return common name if all indices agree, otherwise None (level-by-level).
5949
+
5950
+ Parameters
5951
+ ----------
5952
+ indexes : list of Index objects
5953
+
5954
+ Returns
5955
+ -------
5956
+ list
5957
+ A list representing the unanimous 'names' found.
5958
+ """
5959
+ name_tups = [tuple (i .names ) for i in indexes ]
5960
+ name_sets = [{* ns } for ns in zip_longest (* name_tups )]
5961
+ names = tuple (ns .pop () if len (ns ) == 1 else None for ns in name_sets )
5962
+ return names
0 commit comments