Skip to content

Commit 3a8aa5e

Browse files
committed
allow for top and mid-level assignment to DataFrames with MultIndex columns
1 parent 03e1c89 commit 3a8aa5e

File tree

4 files changed

+169
-18
lines changed

4 files changed

+169
-18
lines changed

pandas/core/frame.py

+48-17
Original file line numberDiff line numberDiff line change
@@ -3189,10 +3189,18 @@ def _setitem_array(self, key, value):
31893189
self.iloc[indexer] = value
31903190
else:
31913191
if isinstance(value, DataFrame):
3192-
if len(value.columns) != len(key):
3193-
raise ValueError("Columns must be same length as key")
3194-
for k1, k2 in zip(key, value.columns):
3195-
self[k1] = value[k2]
3192+
columns = value.columns
3193+
if len(columns) == len(key):
3194+
for k1, k2 in zip(key, columns):
3195+
self[k1] = value[k2]
3196+
elif columns.nlevels > 1 and len(columns.levels[0]) == len(key):
3197+
for k1, k2 in zip(key, columns.levels[0]):
3198+
self[k1] = value[k2]
3199+
else:
3200+
raise ValueError(
3201+
"Key must be same length as columns or top level of "
3202+
"MultiIndex"
3203+
)
31963204
else:
31973205
self.loc._ensure_listlike_indexer(key, axis=1, value=value)
31983206
indexer = self.loc._get_listlike_indexer(
@@ -3221,19 +3229,42 @@ def _setitem_frame(self, key, value):
32213229
def _set_item_frame_value(self, key, value: "DataFrame") -> None:
32223230
self._ensure_valid_index(value)
32233231

3224-
# align right-hand-side columns if self.columns
3225-
# is multi-index and self[key] is a sub-frame
3226-
if isinstance(self.columns, MultiIndex) and key in self.columns:
3227-
loc = self.columns.get_loc(key)
3228-
if isinstance(loc, (slice, Series, np.ndarray, Index)):
3229-
cols = maybe_droplevels(self.columns[loc], key)
3230-
if len(cols) and not cols.equals(value.columns):
3231-
value = value.reindex(cols, axis=1)
3232-
3233-
# now align rows
3234-
value = _reindex_for_setitem(value, self.index)
3235-
value = value.T
3236-
self._set_item_mgr(key, value)
3232+
# standardized key info
3233+
key_tup = key if isinstance(key, tuple) else (key,)
3234+
key_len = len(key_tup)
3235+
3236+
if key in self.columns or key_len == self.columns.nlevels:
3237+
# align right-hand-side columns if self.columns
3238+
# is multi-index and self[key] is a sub-frame
3239+
if isinstance(self.columns, MultiIndex) and key in self.columns:
3240+
loc = self.columns.get_loc(key)
3241+
if isinstance(loc, (slice, Series, np.ndarray, Index)):
3242+
cols = maybe_droplevels(self.columns[loc], key)
3243+
if len(cols) and not cols.equals(value.columns):
3244+
value = value.reindex(cols, axis=1)
3245+
3246+
# now align rows
3247+
value = _reindex_for_setitem(value, self.index)
3248+
value = value.T
3249+
self._set_item_mgr(key, value)
3250+
else:
3251+
if key_len + value.columns.nlevels != self.columns.nlevels:
3252+
raise ValueError(
3253+
"Must pass key/value pair that conforms with number of column "
3254+
"levels"
3255+
)
3256+
3257+
# fill out keys as necessary
3258+
if value.columns.nlevels > 1:
3259+
key_list = [key_tup + i for i in value.columns]
3260+
else:
3261+
key_list = [key_tup + (i,) for i in value.columns]
3262+
items = MultiIndex.from_tuples(key_list)
3263+
3264+
# align and append block
3265+
value = _reindex_for_setitem(value, self.index)
3266+
value = value.T
3267+
self._mgr.append_block(items, value)
32373268

32383269
def _iset_item_mgr(self, loc: int, value) -> None:
32393270
self._mgr.iset(loc, value)

pandas/core/internals/managers.py

+20
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,26 @@ def insert(self, loc: int, item: Label, value, allow_duplicates: bool = False):
12321232
stacklevel=5,
12331233
)
12341234

1235+
def append_block(self, items, values):
1236+
base, size = len(self.items), len(items)
1237+
1238+
new_axis = self.items.append(items)
1239+
block = make_block(
1240+
values=values, ndim=self.ndim, placement=slice(base, base + size)
1241+
)
1242+
1243+
blk_no = len(self.blocks)
1244+
self._blklocs = np.append(self.blklocs, range(size))
1245+
self._blknos = np.append(self.blknos, size * (blk_no,))
1246+
1247+
self.axes[0] = new_axis
1248+
self.blocks += (block,)
1249+
1250+
self._known_consolidated = False
1251+
1252+
if len(self.blocks) > 100:
1253+
self._consolidate_inplace()
1254+
12351255
def reindex_axis(
12361256
self,
12371257
new_index,

pandas/tests/frame/indexing/test_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_setitem_list(self, float_frame):
119119
tm.assert_series_equal(float_frame["B"], data["A"], check_names=False)
120120
tm.assert_series_equal(float_frame["A"], data["B"], check_names=False)
121121

122-
msg = "Columns must be same length as key"
122+
msg = "Key must be same length as columns or top level of MultiIndex"
123123
with pytest.raises(ValueError, match=msg):
124124
data[["A"]] = float_frame[["A", "B"]]
125125
newcolumndata = range(len(data.index) - 1)

pandas/tests/indexing/multiindex/test_multiindex.py

+100
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pandas._libs.index as _index
45
from pandas.errors import PerformanceWarning
@@ -93,3 +94,102 @@ def test_multiindex_with_datatime_level_preserves_freq(self):
9394
result = df.loc[0].index
9495
tm.assert_index_equal(result, dti)
9596
assert result.freq == dti.freq
97+
98+
def test_multiindex_get_loc_list_raises(self):
99+
# https://github.com/pandas-dev/pandas/issues/35878
100+
idx = pd.MultiIndex.from_tuples([("a", 1), ("b", 2)])
101+
msg = "unhashable type"
102+
with pytest.raises(TypeError, match=msg):
103+
idx.get_loc([])
104+
105+
def test_multiindex_frame_assign(self):
106+
df0 = pd.DataFrame({"a": [0, 1, 2, 3], "b": [3, 4, 5, 6]})
107+
df1 = pd.concat({"x": df0, "y": df0}, axis=1)
108+
df2 = pd.concat({"q": df1, "r": df1}, axis=1)
109+
110+
# level one assign
111+
result = df2.copy()
112+
result["m"] = result["q"] + result["r"]
113+
expected = pd.concat({"q": df1, "r": df1, "m": 2 * df1}, axis=1)
114+
tm.assert_frame_equal(result, expected)
115+
116+
# level one assign - multiple
117+
result = df2.copy()
118+
result[["m", "n"]] = 2 * result[["q", "r"]]
119+
expected = pd.concat({"q": df1, "r": df1, "m": 2 * df1, "n": 2 * df1}, axis=1)
120+
tm.assert_frame_equal(result, expected)
121+
122+
# level two assign
123+
result = df2.copy()
124+
result["m", "x"] = df2["q", "x"] + df2["q", "y"]
125+
expected = pd.concat(
126+
{"q": df1, "r": df1, "m": pd.concat({"x": 2 * df0}, axis=1)}, axis=1
127+
)
128+
tm.assert_frame_equal(result, expected)
129+
130+
# level two assign - multiple (seems like getitem is not caught up here)
131+
result = df2.copy()
132+
result[[("m", "x"), ("n", "y")]] = 2 * df2["q"]
133+
expected = pd.concat(
134+
{
135+
"q": df1,
136+
"r": df1,
137+
"m": pd.concat({"x": 2 * df0}, axis=1),
138+
"n": pd.concat({"y": 2 * df0}, axis=1),
139+
},
140+
axis=1,
141+
)
142+
tm.assert_frame_equal(result, expected)
143+
144+
# level three assign
145+
result = df2.copy()
146+
result["m", "x", "a"] = df2["q", "x", "a"] + df2["q", "x", "b"]
147+
expected = pd.concat(
148+
{
149+
"q": df1,
150+
"r": df1,
151+
"m": pd.concat(
152+
{"x": pd.concat({"a": df0["a"] + df0["b"]}, axis=1)}, axis=1
153+
),
154+
},
155+
axis=1,
156+
)
157+
tm.assert_frame_equal(result, expected)
158+
159+
# level three assign - multiple
160+
result = df2.copy()
161+
result[[("m", "x", "a"), ("n", "y", "b")]] = 2 * df2["q", "x"]
162+
expected = pd.concat(
163+
{
164+
"q": df1,
165+
"r": df1,
166+
"m": pd.concat({"x": pd.concat({"a": 2 * df0["a"]}, axis=1)}, axis=1),
167+
"n": pd.concat({"y": pd.concat({"b": 2 * df0["b"]}, axis=1)}, axis=1),
168+
},
169+
axis=1,
170+
)
171+
tm.assert_frame_equal(result, expected)
172+
173+
# invalid usage
174+
msg = "Must pass key/value pair that conforms with number of column levels"
175+
msg2 = "Wrong number of items passed 2, placement implies 1"
176+
177+
# too few levels at level one
178+
with pytest.raises(ValueError, match=msg):
179+
df2["m"] = df0
180+
181+
# too few levels at level two - this appears to be desired
182+
# with pytest.raises(ValueError, match=msg):
183+
# df2["m", "x"] = df0["a"]
184+
185+
# too many levels at level one
186+
with pytest.raises(ValueError, match=msg):
187+
df2["m"] = df2
188+
189+
# too many levels at level two
190+
with pytest.raises(ValueError, match=msg):
191+
df2["m", "x"] = df1
192+
193+
# too many levels at level three
194+
with pytest.raises(ValueError, match=msg2):
195+
df2["m", "x", "a"] = df0

0 commit comments

Comments
 (0)