Skip to content

Commit 5608c83

Browse files
committed
allow for top and mid-level assignment to DataFrames with MultiIndex columns
1 parent 4d3b197 commit 5608c83

File tree

4 files changed

+130
-5
lines changed

4 files changed

+130
-5
lines changed

pandas/core/frame.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -3057,10 +3057,15 @@ def __setitem__(self, key, value):
30573057
# to a slice for partial-string date indexing
30583058
return self._setitem_slice(indexer, value)
30593059

3060+
# mimic getitem behavior
3061+
is_single_key = isinstance(key, tuple) or not is_list_like(key)
3062+
30603063
if isinstance(key, DataFrame) or getattr(key, "ndim", None) == 2:
30613064
self._setitem_frame(key, value)
30623065
elif isinstance(key, (Series, np.ndarray, list, Index)):
30633066
self._setitem_array(key, value)
3067+
elif is_single_key and self.columns.nlevels > 1:
3068+
return self._setitem_multilevel(key, value)
30643069
else:
30653070
# set column
30663071
self._set_item(key, value)
@@ -3085,10 +3090,18 @@ def _setitem_array(self, key, value):
30853090
self.iloc._setitem_with_indexer(indexer, value)
30863091
else:
30873092
if isinstance(value, DataFrame):
3088-
if len(value.columns) != len(key):
3089-
raise ValueError("Columns must be same length as key")
3090-
for k1, k2 in zip(key, value.columns):
3091-
self[k1] = value[k2]
3093+
columns = value.columns
3094+
if len(columns) == len(key):
3095+
for k1, k2 in zip(key, columns):
3096+
self[k1] = value[k2]
3097+
elif columns.nlevels > 1 and len(columns.levels[0]) == len(key):
3098+
for k1, k2 in zip(key, columns.levels[0]):
3099+
self[k1] = value[k2]
3100+
else:
3101+
raise ValueError(
3102+
"Key must be same length as columns or top level of "
3103+
"MultiIndex"
3104+
)
30923105
else:
30933106
self.loc._ensure_listlike_indexer(key, axis=1)
30943107
indexer = self.loc._get_listlike_indexer(
@@ -3114,6 +3127,30 @@ def _setitem_frame(self, key, value):
31143127
self._check_setitem_copy()
31153128
self._where(-key, value, inplace=True)
31163129

3130+
def _setitem_multilevel(self, key, value):
    """
    Set ``value`` under ``key`` when ``self.columns`` is a MultiIndex.

    A key that is already contained in ``self.columns`` is handed to
    ``self._set_item``.  Any other key is treated as a (possibly partial)
    column label, and the new column(s) are added at the end through
    ``self._mgr.append_block``.

    Parameters
    ----------
    key : scalar or tuple
        Column label; a non-tuple key is wrapped into a 1-tuple.
    value : DataFrame or other
        For a DataFrame, each of its column labels is appended to ``key``
        to build the new labels.  For anything else, a key shorter than
        ``self.columns.nlevels`` is padded with ``""`` and a single new
        label is created.

    Raises
    ------
    TypeError
        If ``value`` is a DataFrame and the length of ``key`` plus the
        number of column levels of ``value`` differs from
        ``self.columns.nlevels``.
    """
    # Keys the column index already contains take the regular path.
    if key in self.columns:
        self._set_item(key, value)
        return

    n_levels = self.columns.nlevels
    prefix = key if isinstance(key, tuple) else (key,)

    if isinstance(value, DataFrame):
        value_levels = value.columns.nlevels
        if len(prefix) + value_levels != n_levels:
            raise TypeError(
                "Must pass key/value pair that conforms with number of column "
                "levels"
            )
        if value_levels > 1:
            # MultiIndex columns iterate as tuples; concatenate directly.
            labels = [prefix + col for col in value.columns]
        else:
            # Flat columns iterate as scalars; wrap each one first.
            labels = [prefix + (col,) for col in value.columns]
    else:
        missing = n_levels - len(prefix)
        if missing > 0:
            # Fill the unspecified lower levels with empty strings.
            prefix = prefix + ("",) * missing
        labels = [prefix]

    items = MultiIndex.from_tuples(labels)
    value = self._sanitize_column(prefix, value)
    self._mgr.append_block(items, value)
3153+
31173154
def _iset_item(self, loc: int, value):
31183155
self._ensure_valid_index(value)
31193156

pandas/core/internals/managers.py

+20
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,26 @@ def insert(self, loc: int, item: Label, value, allow_duplicates: bool = False):
12041204
if len(self.blocks) > 100:
12051205
self._consolidate_inplace()
12061206

1207+
def append_block(self, items, values):
    """
    Add ``values`` as one new block placed after all existing items.

    Parameters
    ----------
    items : Index
        Labels for the new entries; they are appended to ``self.items``.
    values
        Data handed to ``make_block``; the block is placed at positions
        ``len(self.items)`` through ``len(self.items) + len(items) - 1``.

    Notes
    -----
    Marks the manager as not known to be consolidated, and consolidates
    in place once there are more than 100 blocks.
    """
    n_existing = len(self.items)
    n_new = len(items)
    stop = n_existing + n_new

    extended_items = self.items.append(items)
    new_block = make_block(
        values=values, ndim=self.ndim, placement=slice(n_existing, stop)
    )

    # The new block's number is the current block count; within that block
    # the new entries sit at positions 0..n_new-1.
    new_blk_no = len(self.blocks)
    self._blklocs = np.append(self.blklocs, range(n_new))
    self._blknos = np.append(self.blknos, (new_blk_no,) * n_new)

    self.axes[0] = extended_items
    self.blocks = self.blocks + (new_block,)

    self._known_consolidated = False

    if len(self.blocks) > 100:
        self._consolidate_inplace()
12071227
def reindex_axis(
12081228
self,
12091229
new_index,

pandas/tests/frame/indexing/test_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_setitem_list(self, float_frame):
157157
tm.assert_series_equal(float_frame["B"], data["A"], check_names=False)
158158
tm.assert_series_equal(float_frame["A"], data["B"], check_names=False)
159159

160-
msg = "Columns must be same length as key"
160+
msg = "Key must be same length as columns or top level of MultiIndex"
161161
with pytest.raises(ValueError, match=msg):
162162
data[["A"]] = float_frame[["A", "B"]]
163163
newcolumndata = range(len(data.index) - 1)

pandas/tests/indexing/multiindex/test_multiindex.py

+68
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,71 @@ def test_multiindex_get_loc_list_raises(self):
9191
msg = "unhashable type"
9292
with pytest.raises(TypeError, match=msg):
9393
idx.get_loc([])
94+
95+
def test_multiindex_frame_assign(self):
    """Assign new top-, mid- and bottom-level keys on 3-level columns."""
    leaf = pd.DataFrame({"a": [0, 1, 2, 3], "b": [3, 4, 5, 6]})
    mid = pd.concat({"x": leaf, "y": leaf}, axis=1)
    top = pd.concat({"q": mid, "r": mid}, axis=1)

    def nest(label, obj):
        # Put ``obj`` under one extra column level named ``label``.
        return pd.concat({label: obj}, axis=1)

    def expect(**extra):
        # The original "q"/"r" groups followed by the newly assigned ones.
        return pd.concat({"q": mid, "r": mid, **extra}, axis=1)

    # level one assign
    result = top.copy()
    result["m"] = result["q"] + result["r"]
    tm.assert_frame_equal(result, expect(m=2 * mid))

    # level one assign - multiple
    result = top.copy()
    result[["m", "n"]] = 2 * result[["q", "r"]]
    tm.assert_frame_equal(result, expect(m=2 * mid, n=2 * mid))

    # level two assign
    result = top.copy()
    result["m", "x"] = top["q", "x"] + top["q", "y"]
    tm.assert_frame_equal(result, expect(m=nest("x", 2 * leaf)))

    # level two assign - multiple
    # (original author's note: getitem does not seem to be caught up here)
    result = top.copy()
    result[[("m", "x"), ("n", "y")]] = 2 * top["q"]
    expected = expect(m=nest("x", 2 * leaf), n=nest("y", 2 * leaf))
    tm.assert_frame_equal(result, expected)

    # level three assign
    result = top.copy()
    result["m", "x", "a"] = top["q", "x", "a"] + top["q", "x", "b"]
    expected = expect(m=nest("x", nest("a", leaf["a"] + leaf["b"])))
    tm.assert_frame_equal(result, expected)

    # level three assign - multiple
    result = top.copy()
    result[[("m", "x", "a"), ("n", "y", "b")]] = 2 * top["q", "x"]
    expected = expect(
        m=nest("x", nest("a", 2 * leaf["a"])),
        n=nest("y", nest("b", 2 * leaf["b"])),
    )
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)