Skip to content

Commit 46bd88f

Browse files
rohanjain101 (Rohan Jain) and mroeschke
authored
preserve index in list accessor (#58438)
* preserve index in list accessor
* gh reference
* explode fix
* cleanup
* improve test
* Update v3.0.0.rst

Co-authored-by: Matthew Roeschke <[email protected]>

* f

---------

Co-authored-by: Rohan Jain <[email protected]>
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 7cdee7a commit 46bd88f

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -483,6 +483,7 @@ Other
483483
- Bug in :meth:`Series.rank` that doesn't preserve missing values for nullable integers when ``na_option='keep'``. (:issue:`56976`)
484484
- Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` inconsistently replacing matching instances when ``regex=True`` and missing values are present. (:issue:`56599`)
485485
- Bug in the DataFrame Interchange Protocol implementation returning incorrect results for data buffers' associated dtype, for string and datetime columns (:issue:`54781`)
486+
- Bug in ``Series.list`` methods not preserving the original :class:`Index`. (:issue:`58425`)
486487

487488
.. ***DO NOT USE THIS SECTION***
488489

pandas/core/arrays/arrow/accessors.py

+14-8
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,9 @@ def len(self) -> Series:
110110
from pandas import Series
111111

112112
value_lengths = pc.list_value_length(self._pa_array)
113-
return Series(value_lengths, dtype=ArrowDtype(value_lengths.type))
113+
return Series(
114+
value_lengths, dtype=ArrowDtype(value_lengths.type), index=self._data.index
115+
)
114116

115117
def __getitem__(self, key: int | slice) -> Series:
116118
"""
@@ -149,7 +151,9 @@ def __getitem__(self, key: int | slice) -> Series:
149151
# if key < 0:
150152
# key = pc.add(key, pc.list_value_length(self._pa_array))
151153
element = pc.list_element(self._pa_array, key)
152-
return Series(element, dtype=ArrowDtype(element.type))
154+
return Series(
155+
element, dtype=ArrowDtype(element.type), index=self._data.index
156+
)
153157
elif isinstance(key, slice):
154158
if pa_version_under11p0:
155159
raise NotImplementedError(
@@ -167,7 +171,7 @@ def __getitem__(self, key: int | slice) -> Series:
167171
if step is None:
168172
step = 1
169173
sliced = pc.list_slice(self._pa_array, start, stop, step)
170-
return Series(sliced, dtype=ArrowDtype(sliced.type))
174+
return Series(sliced, dtype=ArrowDtype(sliced.type), index=self._data.index)
171175
else:
172176
raise ValueError(f"key must be an int or slice, got {type(key).__name__}")
173177

@@ -195,15 +199,17 @@ def flatten(self) -> Series:
195199
... )
196200
>>> s.list.flatten()
197201
0 1
198-
1 2
199-
2 3
200-
3 3
202+
0 2
203+
0 3
204+
1 3
201205
dtype: int64[pyarrow]
202206
"""
203207
from pandas import Series
204208

205-
flattened = pc.list_flatten(self._pa_array)
206-
return Series(flattened, dtype=ArrowDtype(flattened.type))
209+
counts = pa.compute.list_value_length(self._pa_array)
210+
flattened = pa.compute.list_flatten(self._pa_array)
211+
index = self._data.index.repeat(counts.fill_null(pa.scalar(0, counts.type)))
212+
return Series(flattened, dtype=ArrowDtype(flattened.type), index=index)
207213

208214

209215
class StructAccessor(ArrowAccessor):

pandas/tests/series/accessors/test_list_accessor.py

+22-3
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,23 @@ def test_list_getitem(list_dtype):
3131
tm.assert_series_equal(actual, expected)
3232

3333

34+
def test_list_getitem_index():
35+
# GH 58425
36+
ser = Series(
37+
[[1, 2, 3], [4, None, 5], None],
38+
dtype=ArrowDtype(pa.list_(pa.int64())),
39+
index=[1, 3, 7],
40+
)
41+
actual = ser.list[1]
42+
expected = Series([2, None, None], dtype="int64[pyarrow]", index=[1, 3, 7])
43+
tm.assert_series_equal(actual, expected)
44+
45+
3446
def test_list_getitem_slice():
3547
ser = Series(
3648
[[1, 2, 3], [4, None, 5], None],
3749
dtype=ArrowDtype(pa.list_(pa.int64())),
50+
index=[1, 3, 7],
3851
)
3952
if pa_version_under11p0:
4053
with pytest.raises(
@@ -44,7 +57,9 @@ def test_list_getitem_slice():
4457
else:
4558
actual = ser.list[1:None:None]
4659
expected = Series(
47-
[[2, 3], [None, 5], None], dtype=ArrowDtype(pa.list_(pa.int64()))
60+
[[2, 3], [None, 5], None],
61+
dtype=ArrowDtype(pa.list_(pa.int64())),
62+
index=[1, 3, 7],
4863
)
4964
tm.assert_series_equal(actual, expected)
5065

@@ -61,11 +76,15 @@ def test_list_len():
6176

6277
def test_list_flatten():
6378
ser = Series(
64-
[[1, 2, 3], [4, None], None],
79+
[[1, 2, 3], None, [4, None], [], [7, 8]],
6580
dtype=ArrowDtype(pa.list_(pa.int64())),
6681
)
6782
actual = ser.list.flatten()
68-
expected = Series([1, 2, 3, 4, None], dtype=ArrowDtype(pa.int64()))
83+
expected = Series(
84+
[1, 2, 3, 4, None, 7, 8],
85+
dtype=ArrowDtype(pa.int64()),
86+
index=[0, 0, 0, 2, 2, 4, 4],
87+
)
6988
tm.assert_series_equal(actual, expected)
7089

7190

0 commit comments

Comments
 (0)