Skip to content

Commit 8bd59de

Browse files
committed
multi-column explode
1 parent 5d20815 commit 8bd59de

File tree

2 files changed

+111
-24
lines changed

2 files changed

+111
-24
lines changed

pandas/core/frame.py

+67-23
Original file line numberDiff line numberDiff line change
@@ -7910,16 +7910,23 @@ def stack(self, level: Level = -1, dropna: bool = True):
79107910

79117911
return result.__finalize__(self, method="stack")
79127912

7913-
def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
7913+
def explode(
7914+
self,
7915+
column: str | tuple | list[str | tuple],
7916+
ignore_index: bool = False,
7917+
) -> DataFrame:
79147918
"""
79157919
Transform each element of a list-like to a row, replicating index values.
79167920
79177921
.. versionadded:: 0.25.0
79187922
79197923
Parameters
79207924
----------
7921-
column : str or tuple
7922-
Column to explode.
7925+
column : str or tuple or list thereof
7926+
Column(s) to explode.
7927+
For multiple columns, specify a non-empty list with each element
7928+
be str or tuple, and all specified columns their list-like data
7929+
on same row of the frame must have matching length.
79237930
ignore_index : bool, default False
79247931
If True, the resulting index will be labeled 0, 1, …, n - 1.
79257932
@@ -7934,7 +7941,10 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
79347941
Raises
79357942
------
79367943
ValueError :
7937-
if columns of the frame are not unique.
7944+
* If columns of the frame are not unique.
7945+
* If specified columns to explode is empty list.
7946+
* If specified columns to explode have not matching count of
7947+
elements rowwise in the frame.
79387948
79397949
See Also
79407950
--------
@@ -7953,32 +7963,66 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
79537963
79547964
Examples
79557965
--------
7956-
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
7966+
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
7967+
... 'B': 1,
7968+
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
79577969
>>> df
7958-
A B
7959-
0 [1, 2, 3] 1
7960-
1 foo 1
7961-
2 [] 1
7962-
3 [3, 4] 1
7970+
A B C
7971+
0 [0, 1, 2] 1 [a, b, c]
7972+
1 foo 1 NaN
7973+
2 [] 1 []
7974+
3 [3, 4] 1 [d, e]
79637975
79647976
>>> df.explode('A')
7965-
A B
7966-
0 1 1
7967-
0 2 1
7968-
0 3 1
7969-
1 foo 1
7970-
2 NaN 1
7971-
3 3 1
7972-
3 4 1
7973-
"""
7974-
if not (is_scalar(column) or isinstance(column, tuple)):
7975-
raise ValueError("column must be a scalar")
7977+
A B C
7978+
0 0 1 [a, b, c]
7979+
0 1 1 [a, b, c]
7980+
0 2 1 [a, b, c]
7981+
1 foo 1 NaN
7982+
2 NaN 1 []
7983+
3 3 1 [d, e]
7984+
3 4 1 [d, e]
7985+
7986+
>>> df.explode(list('AC'))
7987+
A B C
7988+
0 0 1 a
7989+
0 1 1 b
7990+
0 2 1 c
7991+
1 foo 1 NaN
7992+
2 NaN 1 NaN
7993+
3 3 1 d
7994+
3 4 1 e
7995+
"""
79767996
if not self.columns.is_unique:
79777997
raise ValueError("columns must be unique")
79787998

7999+
columns: list[str | tuple]
8000+
if is_scalar(column) or isinstance(column, tuple):
8001+
# mypy: List item 0 has incompatible type "Union[str, Tuple[Any, ...],
8002+
# List[Union[str, Tuple[Any, ...]]]]"; expected "Union[str, Tuple[Any, ...]]"
8003+
columns = [column] # type: ignore[list-item]
8004+
elif isinstance(column, list) and all(
8005+
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
8006+
):
8007+
if len(column) == 0:
8008+
raise ValueError("column must be nonempty")
8009+
if len(column) > len(set(column)):
8010+
raise ValueError("column must be unique")
8011+
columns = column
8012+
else:
8013+
raise ValueError("column must be a scalar, tuple, or list thereof")
8014+
79798015
df = self.reset_index(drop=True)
7980-
result = df[column].explode()
7981-
result = df.drop([column], axis=1).join(result)
8016+
if len(columns) == 1:
8017+
result = df[column].explode()
8018+
else:
8019+
mylen = lambda x: len(x) if is_list_like(x) else -1
8020+
counts0 = self[columns[0]].apply(mylen)
8021+
for c in columns[1:]:
8022+
if not all(counts0 == self[c].apply(mylen)):
8023+
raise ValueError("columns must have matching element counts")
8024+
result = DataFrame({c: df[c].explode() for c in columns})
8025+
result = df.drop(columns, axis=1).join(result)
79828026
if ignore_index:
79838027
result.index = ibase.default_index(len(result))
79848028
else:

pandas/tests/frame/methods/test_explode.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,34 @@ def test_error():
99
df = pd.DataFrame(
1010
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
1111
)
12-
with pytest.raises(ValueError, match="column must be a scalar"):
12+
with pytest.raises(
13+
ValueError, match="column must be a scalar, tuple, or list thereof"
14+
):
15+
df.explode([list("AA")])
16+
17+
with pytest.raises(ValueError, match="column must be unique"):
1318
df.explode(list("AA"))
1419

1520
df.columns = list("AA")
1621
with pytest.raises(ValueError, match="columns must be unique"):
1722
df.explode("A")
1823

24+
# GH 39240
25+
df1 = df.assign(C=[["a", "b", "c"], "foo", [], ["d", "e", "f"]])
26+
df1.columns = list("ABC")
27+
with pytest.raises(ValueError, match="columns must have matching element counts"):
28+
df1.explode(list("AC"))
29+
30+
# GH 39240
31+
with pytest.raises(ValueError, match="column must be nonempty"):
32+
df1.explode([])
33+
34+
# GH 39240
35+
df2 = df.assign(C=[["a", "b", "c"], "foo", [], "d"])
36+
df2.columns = list("ABC")
37+
with pytest.raises(ValueError, match="columns must have matching element counts"):
38+
df2.explode(list("AC"))
39+
1940

2041
def test_basic():
2142
df = pd.DataFrame(
@@ -180,3 +201,25 @@ def test_explode_sets():
180201
result = df.explode(column="a").sort_values(by="a")
181202
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
182203
tm.assert_frame_equal(result, expected)
204+
205+
206+
def test_multi_columns():
207+
# GH 39240
208+
df = pd.DataFrame(
209+
{
210+
"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")),
211+
"B": 1,
212+
"C": [["a", "b", "c"], "foo", [], ["d", "e"]],
213+
}
214+
)
215+
result = df.explode(list("AC"))
216+
expected = pd.DataFrame(
217+
{
218+
"A": pd.Series(
219+
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
220+
),
221+
"B": 1,
222+
"C": ["a", "b", "c", "foo", np.nan, "d", "e"],
223+
}
224+
)
225+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)