def explode(self, columns: Union[str, Tuple, List]) -> "DataFrame":
    """
    Transform each element of a list-like to a row, replicating index values.

    Parameters
    ----------
    columns : str, tuple or list
        Column label, or list of column labels, to explode. A list always
        means "several columns"; any other hashable value (including a
        tuple, for MultiIndex columns) is treated as one column label.
        When several columns are given they are exploded together, so in
        every row they must explode to the same number of elements.

    Returns
    -------
    DataFrame
        Exploded lists to rows of the subset columns; index values are
        duplicated for these rows and the original column order is kept.

    Raises
    ------
    TypeError
        If `columns` is neither a list nor a hashable column label.
    ValueError
        If the columns of the frame are not unique, if `columns` is an
        empty list, if a requested label is not a column of the frame, or
        if the requested columns do not explode to the same number of
        elements in every row.

    See Also
    --------
    DataFrame.unstack : Pivot a level of the (necessarily hierarchical)
        index labels.
    DataFrame.melt : Unpivot a DataFrame from wide format to long format.
    Series.explode : Explode a Series from list-like entries to long format.

    Notes
    -----
    Element counts are taken from ``Series.explode`` itself, so scalars and
    empty list-likes both count as a single row (an empty list-like becomes
    ``NaN``). Columns that are not exploded are repeated by row position,
    so a frame with duplicated index labels is handled correctly.

    Examples
    --------
    >>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
    >>> df
               A  B
    0  [1, 2, 3]  1
    1        foo  1
    2         []  1
    3     [3, 4]  1

    >>> df.explode('A')
         A  B
    0    1  1
    0    2  1
    0    3  1
    1  foo  1
    2  NaN  1
    3    3  1
    3    4  1

    Several columns can be exploded in lockstep:

    >>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]],
    ...                    'B': 1,
    ...                    'C': [[7, 8, 9], 'bar', [], [8, 7]]})
    >>> df
               A  B          C
    0  [1, 2, 3]  1  [7, 8, 9]
    1        foo  1        bar
    2         []  1         []
    3     [3, 4]  1     [8, 7]

    >>> df.explode(['A', 'C'])
         A  B    C
    0    1  1    7
    0    2  1    8
    0    3  1    9
    1  foo  1  bar
    2  NaN  1  NaN
    3    3  1    8
    3    4  1    7
    """
    if not self.columns.is_unique:
        raise ValueError("columns must be unique")

    # A list is the only spelling of "several columns"; everything else is
    # a single label, which keeps tuple / scalar labels working.
    if isinstance(columns, list):
        labels = columns
    else:
        try:
            hash(columns)
        except TypeError as err:
            raise TypeError("columns value not list or string") from err
        labels = [columns]

    if not labels:
        raise ValueError("columns must not be empty")

    if not all(label in self.columns for label in labels):
        raise ValueError("column name(s) not in index")

    # Explode against a positional index: each exploded index then holds
    # the source row number repeated once per produced element.
    exploded = [self[label].reset_index(drop=True).explode() for label in labels]

    # Two columns explode compatibly exactly when they repeat every source
    # row the same number of times, i.e. when these indexes are equal.
    row_positions = exploded[0].index
    if not all(other.index.equals(row_positions) for other in exploded[1:]):
        raise ValueError(
            "Elements in `columns` do not have equal length in the same row"
        )

    # Repeat whole rows by position (safe with duplicated index labels),
    # then overwrite the exploded columns in place so column order is kept.
    result = self.iloc[row_positions.values].copy()
    for label, series in zip(labels, exploded):
        # Align by construction; assigning the Series keeps its dtype.
        series.index = result.index
        result[label] = series
    return result