Skip to content

Fix Issue #34923: Inferred dtype at the end of df explode method #35011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
30 changes: 30 additions & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,36 @@ apply and applymap on ``DataFrame`` evaluates first row/column only once

df.apply(func, axis=1)


.. _whatsnew_110.api_breaking.explode_infer_dtype:

Infer dtypes in explode method for DataFrame and Series
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Previously, :meth:`DataFrame.explode` and :meth:`Series.explode` always returned ``object`` dtype for the exploded column. The dtype of the exploded column is now inferred from the exploded values (for example ``int64`` or ``float64``). (:issue:`34923`)

.. ipython:: python

   s = pd.Series([1, 2, 3])
   df = pd.DataFrame({'A': [s, s, s, s], 'B': 1})

*Previous behavior*:

.. code-block:: ipython

   In [3]: df.explode("A").dtypes
   Out[3]:
   A    object
   B     int64
   dtype: object

*New behavior*:

.. ipython:: python

   df.explode("A").dtypes


Increased minimum versions for dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7070,6 +7070,7 @@ def explode(
else:
result.index = self.index.take(result.index)
result = result.reindex(columns=self.columns, copy=False)
result = result.infer_objects()

return result

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3843,7 +3843,7 @@ def explode(self, ignore_index: bool = False) -> "Series":
else:
index = self.index.repeat(counts)

result = self._constructor(values, index=index, name=self.name)
result = self._constructor(values, index=index, name=self.name).infer_objects()

return result

Expand Down
25 changes: 20 additions & 5 deletions pandas/tests/frame/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_basic():
expected = pd.DataFrame(
{
"A": pd.Series(
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=np.float64
),
"B": 1,
}
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_multi_index_rows():
("b", 2),
]
),
dtype=object,
dtype=np.float64,
),
"B": 1,
}
Expand All @@ -74,7 +74,7 @@ def test_multi_index_columns():
("A", 1): pd.Series(
[0, 1, 2, np.nan, np.nan, 3, 4],
index=pd.Index([0, 0, 0, 1, 2, 3, 3]),
dtype=object,
dtype=np.float64,
),
("A", 2): 1,
}
Expand All @@ -93,7 +93,7 @@ def test_usecase():
expected = pd.DataFrame(
{
"A": [11, 11, 11, 11, 11, 22, 22, 22],
"B": np.array([0, 1, 2, 3, 4, 0, 1, 2], dtype=object),
"B": np.array([0, 1, 2, 3, 4, 0, 1, 2], dtype=np.int64),
"C": [10, 10, 10, 10, 10, 20, 20, 20],
},
columns=list("ABC"),
Expand Down Expand Up @@ -160,7 +160,22 @@ def test_duplicate_index(input_dict, input_index, expected_dict, expected_index)
# GH 28005
df = pd.DataFrame(input_dict, index=input_index)
result = df.explode("col1")
expected = pd.DataFrame(expected_dict, index=expected_index, dtype=object)
expected = pd.DataFrame(expected_dict, index=expected_index, dtype=np.int64)
tm.assert_frame_equal(result, expected)


def test_inferred_dtype():
# GH 34923
s = pd.Series([1, None, 3])
df = pd.DataFrame({"A": [s, s], "B": 1})
result = df.explode("A")
expected = pd.DataFrame(
{
"A": np.array([1, np.nan, 3, 1, np.nan, 3], dtype=np.float64),
"B": np.array([1, 1, 1, 1, 1, 1], dtype=np.int64),
},
index=[0, 0, 0, 1, 1, 1],
)
tm.assert_frame_equal(result, expected)


Expand Down
11 changes: 7 additions & 4 deletions pandas/tests/series/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def test_basic():
s = pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd"), name="foo")
result = s.explode()
expected = pd.Series(
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object, name="foo"
[0, 1, 2, np.nan, np.nan, 3, 4],
index=list("aaabcdd"),
dtype=np.float64,
name="foo",
)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -54,7 +57,7 @@ def test_multi_index():
names=["foo", "bar"],
)
expected = pd.Series(
[0, 1, 2, np.nan, np.nan, 3, 4], index=index, dtype=object, name="foo"
[0, 1, 2, np.nan, np.nan, 3, 4], index=index, dtype=np.float64, name="foo"
)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -116,13 +119,13 @@ def test_duplicate_index():
# GH 28005
s = pd.Series([[1, 2], [3, 4]], index=[0, 0])
result = s.explode()
expected = pd.Series([1, 2, 3, 4], index=[0, 0, 0, 0], dtype=object)
expected = pd.Series([1, 2, 3, 4], index=[0, 0, 0, 0], dtype=np.int64)
tm.assert_series_equal(result, expected)


def test_ignore_index():
# GH 34932
s = pd.Series([[1, 2], [3, 4]])
result = s.explode(ignore_index=True)
expected = pd.Series([1, 2, 3, 4], index=[0, 1, 2, 3], dtype=object)
expected = pd.Series([1, 2, 3, 4], index=[0, 1, 2, 3], dtype=np.int64)
tm.assert_series_equal(result, expected)