Skip to content

Commit 6353525

Browse files
committed
multi-column explode
1 parent 3f67dc3 commit 6353525

File tree

2 files changed

+112
-24
lines changed

2 files changed

+112
-24
lines changed

pandas/core/frame.py

+68-23
Original file line numberDiff line numberDiff line change
@@ -8168,16 +8168,23 @@ def stack(self, level: Level = -1, dropna: bool = True):
81688168

81698169
return result.__finalize__(self, method="stack")
81708170

8171-
def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
8171+
def explode(
8172+
self,
8173+
column: str | tuple | list[str | tuple],
8174+
ignore_index: bool = False,
8175+
) -> DataFrame:
81728176
"""
81738177
Transform each element of a list-like to a row, replicating index values.
81748178
81758179
.. versionadded:: 0.25.0
81768180
81778181
Parameters
81788182
----------
8179-
column : str or tuple
8180-
Column to explode.
8183+
column : str or tuple or list thereof
8184+
Column(s) to explode.
8185+
For multiple columns, specify a non-empty list in which each element
8186+
is a str or tuple; for all specified columns, the list-like data
8187+
on the same row of the frame must have matching lengths.
81818188
ignore_index : bool, default False
81828189
If True, the resulting index will be labeled 0, 1, …, n - 1.
81838190
@@ -8192,7 +8199,10 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
81928199
Raises
81938200
------
81948201
ValueError :
8195-
if columns of the frame are not unique.
8202+
* If columns of the frame are not unique.
8203+
* If the specified list of columns to explode is empty.
8204+
* If the specified columns to explode do not have matching counts of
8205+
elements row-wise in the frame.
81968206
81978207
See Also
81988208
--------
@@ -8211,32 +8221,67 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
82118221
82128222
Examples
82138223
--------
8214-
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
8224+
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
8225+
... 'B': 1,
8226+
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
82158227
>>> df
8216-
A B
8217-
0 [1, 2, 3] 1
8218-
1 foo 1
8219-
2 [] 1
8220-
3 [3, 4] 1
8228+
A B C
8229+
0 [0, 1, 2] 1 [a, b, c]
8230+
1 foo 1 NaN
8231+
2 [] 1 []
8232+
3 [3, 4] 1 [d, e]
82218233
82228234
>>> df.explode('A')
8223-
A B
8224-
0 1 1
8225-
0 2 1
8226-
0 3 1
8227-
1 foo 1
8228-
2 NaN 1
8229-
3 3 1
8230-
3 4 1
8231-
"""
8232-
if not (is_scalar(column) or isinstance(column, tuple)):
8233-
raise ValueError("column must be a scalar")
8235+
A B C
8236+
0 0 1 [a, b, c]
8237+
0 1 1 [a, b, c]
8238+
0 2 1 [a, b, c]
8239+
1 foo 1 NaN
8240+
2 NaN 1 []
8241+
3 3 1 [d, e]
8242+
3 4 1 [d, e]
8243+
8244+
>>> df.explode(list('AC'))
8245+
A B C
8246+
0 0 1 a
8247+
0 1 1 b
8248+
0 2 1 c
8249+
1 foo 1 NaN
8250+
2 NaN 1 NaN
8251+
3 3 1 d
8252+
3 4 1 e
8253+
"""
82348254
if not self.columns.is_unique:
82358255
raise ValueError("columns must be unique")
82368256

8257+
columns: list[str | tuple]
8258+
if is_scalar(column) or isinstance(column, tuple):
8259+
# mypy: List item 0 has incompatible type "Union[str, Tuple[Any, ...],
8260+
# List[Union[str, Tuple[Any, ...]]]]"; expected
8261+
# "Union[str, Tuple[Any, ...]]"
8262+
columns = [column] # type: ignore[list-item]
8263+
elif isinstance(column, list) and all(
8264+
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
8265+
):
8266+
if len(column) == 0:
8267+
raise ValueError("column must be nonempty")
8268+
if len(column) > len(set(column)):
8269+
raise ValueError("column must be unique")
8270+
columns = column
8271+
else:
8272+
raise ValueError("column must be a scalar, tuple, or list thereof")
8273+
82378274
df = self.reset_index(drop=True)
8238-
result = df[column].explode()
8239-
result = df.drop([column], axis=1).join(result)
8275+
if len(columns) == 1:
8276+
result = df[column].explode()
8277+
else:
8278+
mylen = lambda x: len(x) if is_list_like(x) else -1
8279+
counts0 = self[columns[0]].apply(mylen)
8280+
for c in columns[1:]:
8281+
if not all(counts0 == self[c].apply(mylen)):
8282+
raise ValueError("columns must have matching element counts")
8283+
result = DataFrame({c: df[c].explode() for c in columns})
8284+
result = df.drop(columns, axis=1).join(result)
82408285
if ignore_index:
82418286
result.index = ibase.default_index(len(result))
82428287
else:

pandas/tests/frame/methods/test_explode.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,34 @@ def test_error():
99
df = pd.DataFrame(
1010
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
1111
)
12-
with pytest.raises(ValueError, match="column must be a scalar"):
12+
with pytest.raises(
13+
ValueError, match="column must be a scalar, tuple, or list thereof"
14+
):
15+
df.explode([list("AA")])
16+
17+
with pytest.raises(ValueError, match="column must be unique"):
1318
df.explode(list("AA"))
1419

1520
df.columns = list("AA")
1621
with pytest.raises(ValueError, match="columns must be unique"):
1722
df.explode("A")
1823

24+
# GH 39240
25+
df1 = df.assign(C=[["a", "b", "c"], "foo", [], ["d", "e", "f"]])
26+
df1.columns = list("ABC")
27+
with pytest.raises(ValueError, match="columns must have matching element counts"):
28+
df1.explode(list("AC"))
29+
30+
# GH 39240
31+
with pytest.raises(ValueError, match="column must be nonempty"):
32+
df1.explode([])
33+
34+
# GH 39240
35+
df2 = df.assign(C=[["a", "b", "c"], "foo", [], "d"])
36+
df2.columns = list("ABC")
37+
with pytest.raises(ValueError, match="columns must have matching element counts"):
38+
df2.explode(list("AC"))
39+
1940

2041
def test_basic():
2142
df = pd.DataFrame(
@@ -180,3 +201,25 @@ def test_explode_sets():
180201
result = df.explode(column="a").sort_values(by="a")
181202
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
182203
tm.assert_frame_equal(result, expected)
204+
205+
206+
def test_multi_columns():
207+
# GH 39240
208+
df = pd.DataFrame(
209+
{
210+
"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")),
211+
"B": 1,
212+
"C": [["a", "b", "c"], "foo", [], ["d", "e"]],
213+
}
214+
)
215+
result = df.explode(list("AC"))
216+
expected = pd.DataFrame(
217+
{
218+
"A": pd.Series(
219+
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
220+
),
221+
"B": 1,
222+
"C": ["a", "b", "c", "foo", np.nan, "d", "e"],
223+
}
224+
)
225+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)