-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Support EAs in Series.unstack #23284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
ced299f
3b63fcb
756dde9
90f84ef
942db1b
36a4450
ee330d6
2fcaf4d
4f46364
e9498a1
72b5a0d
f6b2050
4d679cb
ff7aba7
91587cb
49bdb50
cf8ed73
5902b5b
17d3002
a75806a
2397e89
8ed7c73
b23234c
29a6bb1
19b7cfa
254fe52
2d78d42
a9e6263
ca286f7
2f28638
967c674
f6aa4b9
32bc3de
56e5f2f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -344,6 +344,7 @@ def _unstack_multiple(data, clocs, fill_value=None): | |
if isinstance(data, Series): | ||
dummy = data.copy() | ||
dummy.index = dummy_index | ||
|
||
unstacked = dummy.unstack('__placeholder__', fill_value=fill_value) | ||
new_levels = clevels | ||
new_names = cnames | ||
|
@@ -399,6 +400,8 @@ def unstack(obj, level, fill_value=None): | |
else: | ||
return obj.T.stack(dropna=False) | ||
else: | ||
if is_extension_array_dtype(obj.dtype): | ||
return unstack_extension_series(obj, level, fill_value) | ||
unstacker = _Unstacker(obj.values, obj.index, level=level, | ||
fill_value=fill_value, | ||
constructor=obj._constructor_expanddim) | ||
|
@@ -947,3 +950,22 @@ def make_axis_dummies(frame, axis='minor', transform=None): | |
values = values.take(labels, axis=0) | ||
|
||
return DataFrame(values, columns=items, index=frame.index) | ||
|
||
|
||
def unstack_extension_series(series, level, fill_value): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move this function up to around line 424? It looks like this file has all |
||
from pandas.core.reshape.concat import concat | ||
|
||
dummy_arr = np.arange(len(series)) | ||
# fill_value=-1, since we will do a series.values.take later | ||
result = _Unstacker(dummy_arr, series.index, | ||
level=level, fill_value=-1).get_result() | ||
|
||
out = [] | ||
values = series.values | ||
|
||
for col, indices in result.iteritems(): | ||
out.append(Series(values.take(indices.values, | ||
allow_fill=True, | ||
fill_value=fill_value), | ||
name=col, index=result.index)) | ||
return concat(out, axis='columns') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import itertools | ||
import pytest | ||
import numpy as np | ||
|
||
|
@@ -170,3 +171,40 @@ def test_merge(self, data, na_value): | |
[data[0], data[0], data[1], data[2], na_value], | ||
dtype=data.dtype)}) | ||
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']]) | ||
|
||
@pytest.mark.parametrize("index", [ | ||
pd.MultiIndex.from_product(([['A', 'B'], ['a', 'b']])), | ||
pd.MultiIndex.from_product(([['A', 'B'], ['a', 'b'], ['x', 'y', 'z']])), | ||
|
||
# non-uniform | ||
pd.MultiIndex.from_tuples([('A', 'a'), ('A', 'b'), ('B', 'b')]), | ||
|
||
# three levels, non-uniform | ||
pd.MultiIndex.from_product([('A', 'B'), ('a', 'b', 'c'), (0, 1, 2)]), | ||
pd.MultiIndex.from_tuples([ | ||
('A', 'a', 1), | ||
('A', 'b', 0), | ||
('A', 'a', 0), | ||
('B', 'a', 0), | ||
('B', 'c', 1), | ||
]), | ||
]) | ||
def test_unstack(self, data, index): | ||
data = data[:len(index)] | ||
ser = pd.Series(data, index=index) | ||
|
||
n = index.nlevels | ||
levels = list(range(n)) | ||
# [0, 1, 2] | ||
# -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you're correct. |
||
combinations = itertools.chain.from_iterable( | ||
itertools.permutations(levels, i) for i in range(1, n) | ||
) | ||
|
||
for level in combinations: | ||
result = ser.unstack(level=level) | ||
assert all(isinstance(result[col].values, type(data)) for col in result.columns) | ||
expected = ser.astype(object).unstack(level=level) | ||
result = result.astype(object) | ||
|
||
self.assert_frame_equal(result, expected) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,7 +102,10 @@ def copy(self, deep=False): | |
def astype(self, dtype, copy=True): | ||
if isinstance(dtype, type(self.dtype)): | ||
return type(self)(self._data, context=dtype.context) | ||
return super(DecimalArray, self).astype(dtype, copy) | ||
# need to replace decimal NA | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Series.equal doesn't consider |
||
result = np.asarray(self, dtype=dtype) | ||
result[self.isna()] = np.nan | ||
return result | ||
|
||
def __setitem__(self, key, value): | ||
if pd.api.types.is_list_like(value): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really? what does this change for Categorical?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously, Series[Categorical].unstack() returned DataFrame[object]. Now it'll be a DataFrame[Categorical], i.e. unstack() preserves the CategoricalDtype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I forget. Previously, we internally went Categorical -> object -> Categorical. Now we avoid the conversion to categorical. So the changes from 0.23.4 will be
Once DatetimeTZ is an ExtensionArray, we'll presumably preserve that as well. On 0.23.4, we convert to datetime64ns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, this might need a larger note then