Skip to content

Commit bea26f7

Browse files
author
Santhosh18
committed
Modified test cases and added detailed explanation in v1.1.0.rst
1 parent 1dfec9e commit bea26f7

File tree

4 files changed

+58
-11
lines changed

4 files changed

+58
-11
lines changed

doc/source/whatsnew/v1.1.0.rst

+32
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,38 @@ apply and applymap on ``DataFrame`` evaluates first row/column only once
661661
662662
df.apply(func, axis=1)
663663
664+
.. _whatsnew_110.api_breaking.explode_infer_dtype:
665+
666+
Infer dtypes in explode method for Dataframe and Series
667+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
668+
669+
Using :meth:`DataFrame.explode` and :meth:`Series.explode` would always return an object for the column being exploded. Now the dtype of the column would be inferred and returned accordingly. (:issue:`34923`)
670+
671+
.. ipython:: python
672+
673+
s = pd.Series([1,2,3])
674+
df = pd.DataFrame({'A': [s, s, s, s], 'B': 1})
675+
676+
*Previous behavior*:
677+
678+
.. code-block:: ipython
679+
680+
In [3]: df.explode("A").dtypes
681+
Out[3]:
682+
A object
683+
B int64
684+
dtype: object
685+
686+
*New behavior*:
687+
688+
.. code-block:: ipython
689+
690+
In [3]: df.explode("A").dtypes
691+
Out[3]:
692+
A int64
693+
B int64
694+
dtype: object
695+
664696
.. _whatsnew_110.api.other:
665697

666698
Other API changes

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3838,7 +3838,7 @@ def explode(self, ignore_index: bool = False) -> "Series":
38383838
else:
38393839
index = self.index.repeat(counts)
38403840

3841-
result = self._constructor(values, index=index, name=self.name)
3841+
result = self._constructor(values, index=index, name=self.name).infer_objects()
38423842

38433843
return result
38443844

pandas/tests/frame/methods/test_explode.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_basic():
2525
expected = pd.DataFrame(
2626
{
2727
"A": pd.Series(
28-
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
28+
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=np.float64
2929
),
3030
"B": 1,
3131
}
@@ -55,7 +55,7 @@ def test_multi_index_rows():
5555
("b", 2),
5656
]
5757
),
58-
dtype=object,
58+
dtype=np.float64,
5959
),
6060
"B": 1,
6161
}
@@ -74,7 +74,7 @@ def test_multi_index_columns():
7474
("A", 1): pd.Series(
7575
[0, 1, 2, np.nan, np.nan, 3, 4],
7676
index=pd.Index([0, 0, 0, 1, 2, 3, 3]),
77-
dtype=object,
77+
dtype=np.float64,
7878
),
7979
("A", 2): 1,
8080
}
@@ -93,7 +93,7 @@ def test_usecase():
9393
expected = pd.DataFrame(
9494
{
9595
"A": [11, 11, 11, 11, 11, 22, 22, 22],
96-
"B": np.array([0, 1, 2, 3, 4, 0, 1, 2], dtype=object),
96+
"B": np.array([0, 1, 2, 3, 4, 0, 1, 2], dtype=np.int64),
9797
"C": [10, 10, 10, 10, 10, 20, 20, 20],
9898
},
9999
columns=list("ABC"),
@@ -160,7 +160,22 @@ def test_duplicate_index(input_dict, input_index, expected_dict, expected_index)
160160
# GH 28005
161161
df = pd.DataFrame(input_dict, index=input_index)
162162
result = df.explode("col1")
163-
expected = pd.DataFrame(expected_dict, index=expected_index, dtype=object)
163+
expected = pd.DataFrame(expected_dict, index=expected_index, dtype=np.int64)
164+
tm.assert_frame_equal(result, expected)
165+
166+
167+
def test_inferred_dtype():
168+
# GH 34923
169+
s = pd.Series([1, None, 3])
170+
df = pd.DataFrame({'A': [s, s], "B": 1})
171+
result = df.explode("A")
172+
expected = pd.DataFrame(
173+
{
174+
"A": np.array([1, np.nan, 3, 1, np.nan, 3], dtype=np.float64),
175+
"B": np.array([1, 1, 1, 1, 1, 1], dtype=np.int64)
176+
},
177+
index=[0, 0, 0, 1, 1, 1]
178+
)
164179
tm.assert_frame_equal(result, expected)
165180

166181

pandas/tests/series/methods/test_explode.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
def test_basic():
99
s = pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd"), name="foo")
10-
result = s. explode()
10+
result = s.explode()
1111
expected = pd.Series(
12-
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object, name="foo"
12+
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=np.float64, name="foo"
1313
)
1414
tm.assert_series_equal(result, expected)
1515

@@ -54,7 +54,7 @@ def test_multi_index():
5454
names=["foo", "bar"],
5555
)
5656
expected = pd.Series(
57-
[0, 1, 2, np.nan, np.nan, 3, 4], index=index, dtype=object, name="foo"
57+
[0, 1, 2, np.nan, np.nan, 3, 4], index=index, dtype=np.float64, name="foo"
5858
)
5959
tm.assert_series_equal(result, expected)
6060

@@ -116,14 +116,14 @@ def test_duplicate_index():
116116
# GH 28005
117117
s = pd.Series([[1, 2], [3, 4]], index=[0, 0])
118118
result = s.explode()
119-
expected = pd.Series([1, 2, 3, 4], index=[0, 0, 0, 0], dtype=object)
119+
expected = pd.Series([1, 2, 3, 4], index=[0, 0, 0, 0], dtype=np.int64)
120120
tm.assert_series_equal(result, expected)
121121

122122

123123
def test_ignore_index():
124124
# GH 34932
125125
s = pd.Series([[1, 2], [3, 4]])
126126
result = s.explode(ignore_index=True)
127-
expected = pd.Series([1, 2, 3, 4], index=[0, 1, 2, 3], dtype=object)
127+
expected = pd.Series([1, 2, 3, 4], index=[0, 1, 2, 3], dtype=np.int64)
128128
tm.assert_series_equal(result, expected)
129129

0 commit comments

Comments
 (0)