Skip to content

Commit f55ad59

Browse files
committed
EHN: multi-column explode (#39240)
1 parent 3659eda commit f55ad59

File tree

3 files changed

+167
-25
lines changed

3 files changed

+167
-25
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enhancement2
2929

3030
Other enhancements
3131
^^^^^^^^^^^^^^^^^^
32-
-
32+
- :meth:`DataFrame.explode` now supports exploding multiple columns. Its ``column`` argument now also accepts a list of str or tuples for exploding on multiple columns at the same time (:issue:`39240`)
3333
-
3434

3535
.. ---------------------------------------------------------------------------

pandas/core/frame.py

+74-23
Original file line numberDiff line numberDiff line change
@@ -8151,16 +8151,27 @@ def stack(self, level: Level = -1, dropna: bool = True):
81518151

81528152
return result.__finalize__(self, method="stack")
81538153

8154-
def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
8154+
def explode(
8155+
self,
8156+
column: str | tuple | list[str | tuple],
8157+
ignore_index: bool = False,
8158+
) -> DataFrame:
81558159
"""
81568160
Transform each element of a list-like to a row, replicating index values.
81578161
81588162
.. versionadded:: 0.25.0
81598163
81608164
Parameters
81618165
----------
8162-
column : str or tuple
8163-
Column to explode.
8166+
column : str or tuple or list thereof
8167+
Column(s) to explode.
8168+
For multiple columns, specify a non-empty list with each element
8169+
be str or tuple, and all specified columns their list-like data
8170+
on same row of the frame must have matching length.
8171+
8172+
.. versionadded:: 1.4.0
8173+
Multi-column explode
8174+
81648175
ignore_index : bool, default False
81658176
If True, the resulting index will be labeled 0, 1, …, n - 1.
81668177
@@ -8175,7 +8186,10 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
81758186
Raises
81768187
------
81778188
ValueError :
8178-
if columns of the frame are not unique.
8189+
* If columns of the frame are not unique.
8190+
* If specified columns to explode is empty list.
8191+
* If specified columns to explode have not matching count of
8192+
elements rowwise in the frame.
81798193
81808194
See Also
81818195
--------
@@ -8194,32 +8208,69 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
81948208
81958209
Examples
81968210
--------
8197-
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
8211+
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
8212+
... 'B': 1,
8213+
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
81988214
>>> df
8199-
A B
8200-
0 [1, 2, 3] 1
8201-
1 foo 1
8202-
2 [] 1
8203-
3 [3, 4] 1
8215+
A B C
8216+
0 [0, 1, 2] 1 [a, b, c]
8217+
1 foo 1 NaN
8218+
2 [] 1 []
8219+
3 [3, 4] 1 [d, e]
8220+
8221+
Single-column explode.
82048222
82058223
>>> df.explode('A')
8206-
A B
8207-
0 1 1
8208-
0 2 1
8209-
0 3 1
8210-
1 foo 1
8211-
2 NaN 1
8212-
3 3 1
8213-
3 4 1
8214-
"""
8215-
if not (is_scalar(column) or isinstance(column, tuple)):
8216-
raise ValueError("column must be a scalar")
8224+
A B C
8225+
0 0 1 [a, b, c]
8226+
0 1 1 [a, b, c]
8227+
0 2 1 [a, b, c]
8228+
1 foo 1 NaN
8229+
2 NaN 1 []
8230+
3 3 1 [d, e]
8231+
3 4 1 [d, e]
8232+
8233+
Multi-column explode.
8234+
8235+
>>> df.explode(list('AC'))
8236+
A B C
8237+
0 0 1 a
8238+
0 1 1 b
8239+
0 2 1 c
8240+
1 foo 1 NaN
8241+
2 NaN 1 NaN
8242+
3 3 1 d
8243+
3 4 1 e
8244+
"""
82178245
if not self.columns.is_unique:
82188246
raise ValueError("columns must be unique")
82198247

8248+
columns: list[str | tuple]
8249+
if is_scalar(column) or isinstance(column, tuple):
8250+
assert isinstance(column, (str, tuple))
8251+
columns = [column]
8252+
elif isinstance(column, list) and all(
8253+
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
8254+
):
8255+
if not column:
8256+
raise ValueError("column must be nonempty")
8257+
if len(column) > len(set(column)):
8258+
raise ValueError("column must be unique")
8259+
columns = column
8260+
else:
8261+
raise ValueError("column must be a scalar, tuple, or list thereof")
8262+
82208263
df = self.reset_index(drop=True)
8221-
result = df[column].explode()
8222-
result = df.drop([column], axis=1).join(result)
8264+
if len(columns) == 1:
8265+
result = df[columns[0]].explode()
8266+
else:
8267+
mylen = lambda x: len(x) if is_list_like(x) else -1
8268+
counts0 = self[columns[0]].apply(mylen)
8269+
for c in columns[1:]:
8270+
if not all(counts0 == self[c].apply(mylen)):
8271+
raise ValueError("columns must have matching element counts")
8272+
result = DataFrame({c: df[c].explode() for c in columns})
8273+
result = df.drop(columns, axis=1).join(result)
82238274
if ignore_index:
82248275
result.index = ibase.default_index(len(result))
82258276
else:

pandas/tests/frame/methods/test_explode.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,50 @@ def test_error():
99
df = pd.DataFrame(
1010
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
1111
)
12-
with pytest.raises(ValueError, match="column must be a scalar"):
12+
with pytest.raises(
13+
ValueError, match="column must be a scalar, tuple, or list thereof"
14+
):
15+
df.explode([list("AA")])
16+
17+
with pytest.raises(ValueError, match="column must be unique"):
1318
df.explode(list("AA"))
1419

1520
df.columns = list("AA")
1621
with pytest.raises(ValueError, match="columns must be unique"):
1722
df.explode("A")
1823

1924

25+
@pytest.mark.parametrize(
26+
"input_subset, error_message",
27+
[
28+
(
29+
list("AC"),
30+
"columns must have matching element counts",
31+
),
32+
(
33+
[],
34+
"column must be nonempty",
35+
),
36+
(
37+
list("AC"),
38+
"columns must have matching element counts",
39+
),
40+
],
41+
)
42+
def test_error_multi_columns(input_subset, error_message):
43+
# GH 39240
44+
df = pd.DataFrame(
45+
{
46+
"A": [[0, 1, 2], np.nan, [], (3, 4)],
47+
"B": 1,
48+
"C": [["a", "b", "c"], "foo", [], ["d", "e", "f"]],
49+
},
50+
index=list("abcd"),
51+
)
52+
with pytest.raises(ValueError, match=error_message):
53+
df.explode(input_subset)
54+
55+
2056
def test_basic():
2157
df = pd.DataFrame(
2258
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
@@ -180,3 +216,58 @@ def test_explode_sets():
180216
result = df.explode(column="a").sort_values(by="a")
181217
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
182218
tm.assert_frame_equal(result, expected)
219+
220+
221+
@pytest.mark.parametrize(
222+
"input_subset, expected_dict, expected_index",
223+
[
224+
(
225+
list("AC"),
226+
{
227+
"A": pd.Series(
228+
[0, 1, 2, np.nan, np.nan, 3, 4, np.nan],
229+
index=list("aaabcdde"),
230+
dtype=object,
231+
),
232+
"B": 1,
233+
"C": ["a", "b", "c", "foo", np.nan, "d", "e", np.nan],
234+
},
235+
list("aaabcdde"),
236+
),
237+
(
238+
list("A"),
239+
{
240+
"A": pd.Series(
241+
[0, 1, 2, np.nan, np.nan, 3, 4, np.nan],
242+
index=list("aaabcdde"),
243+
dtype=object,
244+
),
245+
"B": 1,
246+
"C": [
247+
["a", "b", "c"],
248+
["a", "b", "c"],
249+
["a", "b", "c"],
250+
"foo",
251+
[],
252+
["d", "e"],
253+
["d", "e"],
254+
np.nan,
255+
],
256+
},
257+
list("aaabcdde"),
258+
),
259+
],
260+
)
261+
def test_multi_columns(input_subset, expected_dict, expected_index):
262+
# GH 39240
263+
df = pd.DataFrame(
264+
{
265+
"A": [[0, 1, 2], np.nan, [], (3, 4), np.nan],
266+
"B": 1,
267+
"C": [["a", "b", "c"], "foo", [], ["d", "e"], np.nan],
268+
},
269+
index=list("abcde"),
270+
)
271+
result = df.explode(input_subset)
272+
expected = pd.DataFrame(expected_dict, expected_index)
273+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)