Skip to content

Commit 47d3ba1

Browse files
committed
fix inconsistent index naming with union/intersect GH35847
1 parent c33c3c0 commit 47d3ba1

File tree

16 files changed

+197
-72
lines changed

16 files changed

+197
-72
lines changed

doc/source/user_guide/merging.rst

+8
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ functionality below.
141141
frames = [ process_your_file(f) for f in files ]
142142
result = pd.concat(frames)
143143

144+
.. note::
145+
146+
When concatenating DataFrames with named axes, pandas will attempt to preserve
147+
these index/column names whenever possible. In the case where all inputs share a
148+
common name, this name will be assigned to the result. When the input names do
149+
not all agree, the result will be unnamed. The same is true for :class:`MultiIndex`,
150+
but the logic is applied separately on a level-by-level basis.
151+
144152

145153
Set logic on the other axes
146154
~~~~~~~~~~~~~~~~~~~~~~~~~~~

doc/source/whatsnew/v1.2.0.rst

+20
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,26 @@ Beginning with this version, the default is now to use the more accurate parser
109109
``floating_precision="legacy"`` to use the legacy parser. The change to using the higher precision
110110
parser by default should have no impact on performance. (:issue:`17154`)
111111

112+
.. _whatsnew_120.index_name_preservation:
113+
114+
Index/column name preservation when aggregating
115+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
116+
117+
When aggregating using :meth:`concat` or the :class:`DataFrame` constructor, Pandas
118+
will attempt to preserve index (and column) names whenever possible (:issue:`35847`).
119+
In the case where all inputs share a common name, this name will be assigned to the
120+
result. When the input names do not all agree, the result will be unnamed. Here is an
121+
example where the index name is preserved:
122+
123+
.. ipython:: python
124+
125+
idx = pd.Index(range(5), name='abc')
126+
ser = pd.Series(range(5, 10), index=idx)
127+
pd.concat({'x': ser[1:], 'y': ser[:-1]}, axis=1)
128+
129+
The same is true for :class:`MultiIndex`, but the logic is applied separately on a
130+
level-by-level basis.
131+
112132
.. _whatsnew_120.enhancements.other:
113133

114134
Other enhancements

pandas/core/indexes/api.py

+4-28
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from pandas._libs import NaT, lib
55
from pandas.errors import InvalidIndexError
66

7-
import pandas.core.common as com
87
from pandas.core.indexes.base import (
98
Index,
109
_new_Index,
1110
ensure_index,
1211
ensure_index_from_sequences,
12+
get_unanimous_names,
1313
)
1414
from pandas.core.indexes.category import CategoricalIndex
1515
from pandas.core.indexes.datetimes import DatetimeIndex
@@ -57,7 +57,7 @@
5757
"ensure_index_from_sequences",
5858
"get_objs_combined_axis",
5959
"union_indexes",
60-
"get_consensus_names",
60+
"get_unanimous_names",
6161
"all_indexes_same",
6262
]
6363

@@ -221,9 +221,9 @@ def conv(i):
221221
if not all(index.equals(other) for other in indexes[1:]):
222222
index = _unique_indices(indexes)
223223

224-
name = get_consensus_names(indexes)[0]
224+
name = get_unanimous_names(*indexes)[0]
225225
if name != index.name:
226-
index = index._shallow_copy(name=name)
226+
index = index.rename(name)
227227
return index
228228
else: # kind='list'
229229
return _unique_indices(indexes)
@@ -267,30 +267,6 @@ def _sanitize_and_check(indexes):
267267
return indexes, "array"
268268

269269

270-
def get_consensus_names(indexes):
271-
"""
272-
Give a consensus 'names' to indexes.
273-
274-
If there's exactly one non-empty 'names', return this,
275-
otherwise, return empty.
276-
277-
Parameters
278-
----------
279-
indexes : list of Index objects
280-
281-
Returns
282-
-------
283-
list
284-
A list representing the consensus 'names' found.
285-
"""
286-
# find the non-none names, need to tupleify to make
287-
# the set hashable, then reverse on return
288-
consensus_names = {tuple(i.names) for i in indexes if com.any_not_none(*i.names)}
289-
if len(consensus_names) == 1:
290-
return list(list(consensus_names)[0])
291-
return [None] * indexes[0].nlevels
292-
293-
294270
def all_indexes_same(indexes):
295271
"""
296272
Determine if all indexes contain the same elements.

pandas/core/indexes/base.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import copy as copy_func
22
from datetime import datetime
3+
from itertools import zip_longest
34
import operator
45
from textwrap import dedent
56
from typing import (
@@ -11,6 +12,8 @@
1112
List,
1213
Optional,
1314
Sequence,
15+
Tuple,
16+
Type,
1417
TypeVar,
1518
Union,
1619
)
@@ -2492,7 +2495,7 @@ def _get_reconciled_name_object(self, other):
24922495
"""
24932496
name = get_op_result_name(self, other)
24942497
if self.name != name:
2495-
return self._shallow_copy(name=name)
2498+
return self.rename(name)
24962499
return self
24972500

24982501
def _union_incompatible_dtypes(self, other, sort):
@@ -2600,7 +2603,9 @@ def union(self, other, sort=None):
26002603
if not self._can_union_without_object_cast(other):
26012604
return self._union_incompatible_dtypes(other, sort=sort)
26022605

2603-
return self._union(other, sort=sort)
2606+
result = self._union(other, sort=sort)
2607+
2608+
return self._wrap_setop_result(other, result)
26042609

26052610
def _union(self, other, sort):
26062611
"""
@@ -2622,10 +2627,10 @@ def _union(self, other, sort):
26222627
Index
26232628
"""
26242629
if not len(other) or self.equals(other):
2625-
return self._get_reconciled_name_object(other)
2630+
return self
26262631

26272632
if not len(self):
2628-
return other._get_reconciled_name_object(self)
2633+
return other
26292634

26302635
# TODO(EA): setops-refactor, clean all this up
26312636
lvals = self._values
@@ -2667,12 +2672,16 @@ def _union(self, other, sort):
26672672
stacklevel=3,
26682673
)
26692674

2670-
# for subclasses
2671-
return self._wrap_setop_result(other, result)
2675+
return self._shallow_copy(result)
26722676

26732677
def _wrap_setop_result(self, other, result):
26742678
name = get_op_result_name(self, other)
2675-
return self._shallow_copy(result, name=name)
2679+
if isinstance(result, Index):
2680+
if result.name != name:
2681+
return result.rename(name)
2682+
return result
2683+
else:
2684+
return self._shallow_copy(result, name=name)
26762685

26772686
# TODO: standardize return type of non-union setops type(self vs other)
26782687
def intersection(self, other, sort=False):
@@ -2742,15 +2751,12 @@ def intersection(self, other, sort=False):
27422751
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
27432752
indexer = indexer[indexer != -1]
27442753

2745-
taken = other.take(indexer)
2746-
res_name = get_op_result_name(self, other)
2754+
result = other.take(indexer)
27472755

27482756
if sort is None:
2749-
taken = algos.safe_sort(taken.values)
2750-
return self._shallow_copy(taken, name=res_name)
2757+
result = algos.safe_sort(result.values)
27512758

2752-
taken.name = res_name
2753-
return taken
2759+
return self._wrap_setop_result(other, result)
27542760

27552761
def difference(self, other, sort=None):
27562762
"""
@@ -5935,3 +5941,22 @@ def _maybe_asobject(dtype, klass, data, copy: bool, name: Label, **kwargs):
59355941
return index.astype(object)
59365942

59375943
return klass(data, dtype=dtype, copy=copy, name=name, **kwargs)
5944+
5945+
5946+
def get_unanimous_names(*indexes: Type[Index]) -> Tuple[Any, ...]:
5947+
"""
5948+
Return common name if all indices agree, otherwise None (level-by-level).
5949+
5950+
Parameters
5951+
----------
5952+
indexes : list of Index objects
5953+
5954+
Returns
5955+
-------
5956+
list
5957+
A list representing the unanimous 'names' found.
5958+
"""
5959+
name_tups = [tuple(i.names) for i in indexes]
5960+
name_sets = [{*ns} for ns in zip_longest(*name_tups)]
5961+
names = tuple(ns.pop() if len(ns) == 1 else None for ns in name_sets)
5962+
return names

pandas/core/indexes/datetimelike.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -719,33 +719,29 @@ def intersection(self, other, sort=False):
719719
"""
720720
self._validate_sort_keyword(sort)
721721
self._assert_can_do_setop(other)
722-
res_name = get_op_result_name(self, other)
723722

724723
if self.equals(other):
725724
return self._get_reconciled_name_object(other)
726725

727726
if len(self) == 0:
728-
return self.copy()
727+
return self.copy()._get_reconciled_name_object(other)
729728
if len(other) == 0:
730-
return other.copy()
729+
return other.copy()._get_reconciled_name_object(self)
731730

732731
if not isinstance(other, type(self)):
733732
result = Index.intersection(self, other, sort=sort)
734733
if isinstance(result, type(self)):
735734
if result.freq is None:
736735
# TODO: no tests rely on this; needed?
737736
result = result._with_freq("infer")
738-
result.name = res_name
739737
return result
740738

741739
elif not self._can_fast_intersect(other):
742740
result = Index.intersection(self, other, sort=sort)
743741
# We need to invalidate the freq because Index.intersection
744742
# uses _shallow_copy on a view of self._data, which will preserve
745743
# self.freq if we're not careful.
746-
result = result._with_freq(None)._with_freq("infer")
747-
result.name = res_name
748-
return result
744+
return result._with_freq(None)._with_freq("infer")
749745

750746
# to make our life easier, "sort" the two ranges
751747
if self[0] <= other[0]:
@@ -759,11 +755,13 @@ def intersection(self, other, sort=False):
759755
start = right[0]
760756

761757
if end < start:
762-
return type(self)(data=[], dtype=self.dtype, freq=self.freq, name=res_name)
758+
result = type(self)(data=[], dtype=self.dtype, freq=self.freq)
763759
else:
764760
lslice = slice(*left.slice_locs(start, end))
765761
left_chunk = left._values[lslice]
766-
return type(self)._simple_new(left_chunk, name=res_name)
762+
result = type(self)._simple_new(left_chunk)
763+
764+
return self._wrap_setop_result(other, result)
767765

768766
def _can_fast_intersect(self: _T, other: _T) -> bool:
769767
if self.freq is None:
@@ -858,7 +856,7 @@ def _fast_union(self, other, sort=None):
858856
# The can_fast_union check ensures that the result.freq
859857
# should match self.freq
860858
dates = type(self._data)(dates, freq=self.freq)
861-
result = type(self)._simple_new(dates, name=self.name)
859+
result = type(self)._simple_new(dates)
862860
return result
863861
else:
864862
return left
@@ -883,8 +881,8 @@ def _union(self, other, sort):
883881
result = result._with_freq("infer")
884882
return result
885883
else:
886-
i8self = Int64Index._simple_new(self.asi8, name=self.name)
887-
i8other = Int64Index._simple_new(other.asi8, name=other.name)
884+
i8self = Int64Index._simple_new(self.asi8)
885+
i8other = Int64Index._simple_new(other.asi8)
888886
i8result = i8self._union(i8other, sort=sort)
889887
result = type(self)(i8result, dtype=self.dtype, freq="infer")
890888
return result

pandas/core/indexes/datetimes.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from pandas.core.arrays.datetimes import DatetimeArray, tz_to_dtype
3333
import pandas.core.common as com
34-
from pandas.core.indexes.base import Index, maybe_extract_name
34+
from pandas.core.indexes.base import Index, get_unanimous_names, maybe_extract_name
3535
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
3636
from pandas.core.indexes.extension import inherit_names
3737
from pandas.core.tools.times import to_time
@@ -383,6 +383,10 @@ def union_many(self, others):
383383
this = this._fast_union(other)
384384
else:
385385
this = Index.union(this, other)
386+
387+
res_name = get_unanimous_names(self, *others)[0]
388+
if this.name != res_name:
389+
return this.rename(res_name)
386390
return this
387391

388392
# --------------------------------------------------------------------

pandas/core/indexes/interval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ def intersection(
10181018
if sort is None:
10191019
taken = taken.sort_values()
10201020

1021-
return taken
1021+
return self._wrap_setop_result(other, taken)
10221022

10231023
def _intersection_unique(self, other: "IntervalIndex") -> "IntervalIndex":
10241024
"""

pandas/core/indexes/multi.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@
4747
from pandas.core.arrays.categorical import factorize_from_iterables
4848
import pandas.core.common as com
4949
import pandas.core.indexes.base as ibase
50-
from pandas.core.indexes.base import Index, _index_shared_docs, ensure_index
50+
from pandas.core.indexes.base import (
51+
Index,
52+
_index_shared_docs,
53+
ensure_index,
54+
get_unanimous_names,
55+
)
5156
from pandas.core.indexes.frozen import FrozenList
5257
from pandas.core.indexes.numeric import Int64Index
5358
import pandas.core.missing as missing
@@ -3415,7 +3420,7 @@ def union(self, other, sort=None):
34153420
other, result_names = self._convert_can_do_setop(other)
34163421

34173422
if len(other) == 0 or self.equals(other):
3418-
return self
3423+
return self.rename(result_names)
34193424

34203425
# TODO: Index.union returns other when `len(self)` is 0.
34213426

@@ -3457,7 +3462,7 @@ def intersection(self, other, sort=False):
34573462
other, result_names = self._convert_can_do_setop(other)
34583463

34593464
if self.equals(other):
3460-
return self
3465+
return self.rename(result_names)
34613466

34623467
if not is_object_dtype(other.dtype):
34633468
# The intersection is empty
@@ -3528,7 +3533,7 @@ def difference(self, other, sort=None):
35283533
other, result_names = self._convert_can_do_setop(other)
35293534

35303535
if len(other) == 0:
3531-
return self
3536+
return self.rename(result_names)
35323537

35333538
if self.equals(other):
35343539
return MultiIndex(
@@ -3576,7 +3581,8 @@ def _convert_can_do_setop(self, other):
35763581
except TypeError as err:
35773582
raise TypeError(msg) from err
35783583
else:
3579-
result_names = self.names if self.names == other.names else None
3584+
result_names = get_unanimous_names(self, other)
3585+
35803586
return other, result_names
35813587

35823588
# --------------------------------------------------------------------

pandas/core/indexes/range.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,8 @@ def intersection(self, other, sort=False):
534534
new_index = new_index[::-1]
535535
if sort is None:
536536
new_index = new_index.sort_values()
537-
return new_index
537+
538+
return self._wrap_setop_result(other, new_index)
538539

539540
def _min_fitting_element(self, lower_limit: int) -> int:
540541
"""Returns the smallest element greater than or equal to the limit"""

0 commit comments

Comments
 (0)