|
12 | 12 |
|
13 | 13 | from pandas import (
|
14 | 14 | Categorical,
|
| 15 | + CategoricalIndex, |
15 | 16 | DataFrame,
|
16 | 17 | DatetimeIndex,
|
17 | 18 | Index,
|
18 | 19 | IntervalIndex,
|
| 20 | + MultiIndex, |
19 | 21 | Series,
|
20 | 22 | Timestamp,
|
21 | 23 | cut,
|
@@ -171,21 +173,6 @@ def test_assign_columns(self, float_frame):
|
171 | 173 | tm.assert_series_equal(float_frame["C"], df["baz"], check_names=False)
|
172 | 174 | tm.assert_series_equal(float_frame["hi"], df["foo2"], check_names=False)
|
173 | 175 |
|
174 |
| - def test_set_index_preserve_categorical_dtype(self): |
175 |
| - # GH13743, GH13854 |
176 |
| - df = DataFrame( |
177 |
| - { |
178 |
| - "A": [1, 2, 1, 1, 2], |
179 |
| - "B": [10, 16, 22, 28, 34], |
180 |
| - "C1": Categorical(list("abaab"), categories=list("bac"), ordered=False), |
181 |
| - "C2": Categorical(list("abaab"), categories=list("bac"), ordered=True), |
182 |
| - } |
183 |
| - ) |
184 |
| - for cols in ["C1", "C2", ["A", "C1"], ["A", "C2"], ["C1", "C2"]]: |
185 |
| - result = df.set_index(cols).reset_index() |
186 |
| - result = result.reindex(columns=df.columns) |
187 |
| - tm.assert_frame_equal(result, df) |
188 |
| - |
189 | 176 | def test_rename_signature(self):
|
190 | 177 | sig = inspect.signature(DataFrame.rename)
|
191 | 178 | parameters = set(sig.parameters)
|
@@ -266,3 +253,47 @@ def test_set_reset_index(self):
|
266 | 253 | df = df.set_index("B")
|
267 | 254 |
|
268 | 255 | df = df.reset_index()
|
| 256 | + |
| 257 | + |
| 258 | +class TestCategoricalIndex: |
| 259 | + def test_set_index_preserve_categorical_dtype(self): |
| 260 | + # GH13743, GH13854 |
| 261 | + df = DataFrame( |
| 262 | + { |
| 263 | + "A": [1, 2, 1, 1, 2], |
| 264 | + "B": [10, 16, 22, 28, 34], |
| 265 | + "C1": Categorical(list("abaab"), categories=list("bac"), ordered=False), |
| 266 | + "C2": Categorical(list("abaab"), categories=list("bac"), ordered=True), |
| 267 | + } |
| 268 | + ) |
| 269 | + for cols in ["C1", "C2", ["A", "C1"], ["A", "C2"], ["C1", "C2"]]: |
| 270 | + result = df.set_index(cols).reset_index() |
| 271 | + result = result.reindex(columns=df.columns) |
| 272 | + tm.assert_frame_equal(result, df) |
| 273 | + |
| 274 | + @pytest.mark.parametrize( |
| 275 | + "codes", ([[0, 0, 1, 1], [0, 1, 0, 1]], [[0, 0, -1, 1], [0, 1, 0, 1]]) |
| 276 | + ) |
| 277 | + def test_reindexing_with_missing_values(self, codes): |
| 278 | + # GH 24206 |
| 279 | + |
| 280 | + index = MultiIndex( |
| 281 | + [CategoricalIndex(["A", "B"]), CategoricalIndex(["a", "b"])], codes |
| 282 | + ) |
| 283 | + data = {"col": range(len(index))} |
| 284 | + df = DataFrame(data=data, index=index) |
| 285 | + |
| 286 | + expected = DataFrame( |
| 287 | + { |
| 288 | + "level_0": Categorical.from_codes(codes[0], categories=["A", "B"]), |
| 289 | + "level_1": Categorical.from_codes(codes[1], categories=["a", "b"]), |
| 290 | + "col": range(4), |
| 291 | + } |
| 292 | + ) |
| 293 | + |
| 294 | + res = df.reset_index() |
| 295 | + tm.assert_frame_equal(res, expected) |
| 296 | + |
| 297 | + # roundtrip |
| 298 | + res = expected.set_index(["level_0", "level_1"]).reset_index() |
| 299 | + tm.assert_frame_equal(res, expected) |
0 commit comments