Skip to content

Commit 8537396

Browse files
committed
multi-column explode
1 parent 4554635 commit 8537396

File tree

2 files changed

+82
-23
lines changed

2 files changed

+82
-23
lines changed

pandas/core/frame.py

+53-22
Original file line numberDiff line numberDiff line change
@@ -7925,7 +7925,9 @@ def stack(self, level: Level = -1, dropna: bool = True):
79257925
return result.__finalize__(self, method="stack")
79267926

79277927
def explode(
7928-
self, column: Union[str, Tuple], ignore_index: bool = False
7928+
self,
7929+
column: Union[str, Tuple, List[Union[str, Tuple]]],
7930+
ignore_index: bool = False
79297931
) -> DataFrame:
79307932
"""
79317933
Transform each element of a list-like to a row, replicating index values.
@@ -7934,8 +7936,8 @@ def explode(
79347936
79357937
Parameters
79367938
----------
7937-
column : str or tuple
7938-
Column to explode.
7939+
column : str or tuple or list thereof
7940+
Column(s) to explode.
79397941
ignore_index : bool, default False
79407942
If True, the resulting index will be labeled 0, 1, …, n - 1.
79417943
@@ -7969,32 +7971,61 @@ def explode(
79697971
79707972
Examples
79717973
--------
7972-
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
7974+
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
7975+
... 'B': 1,
7976+
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
79737977
>>> df
7974-
A B
7975-
0 [1, 2, 3] 1
7976-
1 foo 1
7977-
2 [] 1
7978-
3 [3, 4] 1
7978+
A B C
7979+
0 [0, 1, 2] 1 [a, b, c]
7980+
1 foo 1 NaN
7981+
2 [] 1 []
7982+
3 [3, 4] 1 [d, e]
79797983
79807984
>>> df.explode('A')
7981-
A B
7982-
0 1 1
7983-
0 2 1
7984-
0 3 1
7985-
1 foo 1
7986-
2 NaN 1
7987-
3 3 1
7988-
3 4 1
7989-
"""
7990-
if not (is_scalar(column) or isinstance(column, tuple)):
7991-
raise ValueError("column must be a scalar")
7985+
A B C
7986+
0 0 1 [a, b, c]
7987+
0 1 1 [a, b, c]
7988+
0 2 1 [a, b, c]
7989+
1 foo 1 NaN
7990+
2 NaN 1 []
7991+
3 3 1 [d, e]
7992+
3 4 1 [d, e]
7993+
7994+
>>> df.explode(list('AC'))
7995+
A B C
7996+
0 0 1 a
7997+
0 1 1 b
7998+
0 2 1 c
7999+
1 foo 1 NaN
8000+
2 NaN 1 NaN
8001+
3 3 1 d
8002+
3 4 1 e
8003+
"""
79928004
if not self.columns.is_unique:
79938005
raise ValueError("columns must be unique")
8006+
if (is_scalar(column) or isinstance(column, tuple)):
8007+
columns = [column]
8008+
elif (isinstance(column, list) and
8009+
all(map(lambda c: is_scalar(c) or isinstance(c, tuple),
8010+
column))):
8011+
if len(column) > len(set(column)):
8012+
raise ValueError("column must be unique")
8013+
# mypy: Incompatible types in assignment (expression has type
8014+
# "List[Union[str, Tuple[Any, ...]]]", variable has type
8015+
# "List[Union[str, Tuple[Any, ...], List[Union[str, Tuple[Any, ...]]]]]")
8016+
columns = column # type: ignore[assignment]
8017+
else:
8018+
raise ValueError("column must be a scalar, tuple, or list thereof")
8019+
8020+
mylen = lambda x: len(x) if is_list_like(x) else -1
8021+
counts0 = self[columns[0]].apply(mylen)
8022+
for c in columns[1:]:
8023+
if not all(counts0 == self[c].apply(mylen)):
8024+
raise ValueError("columns must have matching element counts")
79948025

79958026
df = self.reset_index(drop=True)
7996-
result = df[column].explode()
7997-
result = df.drop([column], axis=1).join(result)
8027+
result = DataFrame({c:df[c].explode() for c in columns})
8028+
result = df.drop(columns, axis=1).join(result)
79988029
if ignore_index:
79998030
result.index = ibase.default_index(len(result))
80008031
else:

pandas/tests/frame/methods/test_explode.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,21 @@ def test_error():
99
df = pd.DataFrame(
1010
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
1111
)
12-
with pytest.raises(ValueError, match="column must be a scalar"):
12+
with pytest.raises(ValueError, match="column must be a scalar, tuple, or list thereof"):
13+
df.explode([list("AA")])
14+
15+
with pytest.raises(ValueError, match="column must be unique"):
1316
df.explode(list("AA"))
1417

1518
df.columns = list("AA")
1619
with pytest.raises(ValueError, match="columns must be unique"):
1720
df.explode("A")
1821

22+
df1 = df.assign(C=[["a", "b", "c"], "foo", [], ["d", "e", "f"]])
23+
df1.columns = list("ABC")
24+
with pytest.raises(ValueError, match="columns must have matching element counts"):
25+
df1.explode(list("AC"))
26+
1927

2028
def test_basic():
2129
df = pd.DataFrame(
@@ -180,3 +188,23 @@ def test_explode_sets():
180188
result = df.explode(column="a").sort_values(by="a")
181189
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
182190
tm.assert_frame_equal(result, expected)
191+
192+
193+
def test_multi_columns():
    """Explode columns "A" and "C" together and compare with the expected frame."""
    # "A" and "C" carry list-likes with matching per-row element counts;
    # the scalar entries (NaN / "foo") and the empty lists each map to a
    # single output row in the expected frame below.
    col_a = pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd"))
    col_c = [["a", "b", "c"], "foo", [], ["d", "e"]]
    frame = pd.DataFrame({"A": col_a, "B": 1, "C": col_c})

    exploded = frame.explode(list("AC"))

    # The original index labels are repeated once per exploded element.
    want_a = pd.Series(
        [0, 1, 2, np.nan, np.nan, 3, 4],
        index=list("aaabcdd"),
        dtype=object,
    )
    want_c = ["a", "b", "c", "foo", np.nan, "d", "e"]
    want = pd.DataFrame({"A": want_a, "B": 1, "C": want_c})

    tm.assert_frame_equal(exploded, want)
210+

0 commit comments

Comments
 (0)