From 153077f4c74562758263562826ac4158ae9ae948 Mon Sep 17 00:00:00 2001 From: Andy Hayden Date: Sat, 2 Mar 2013 13:44:45 +0000 Subject: [PATCH] keep name after DataFrame drop, add check_name to assert_dataframe_equal --- pandas/core/generic.py | 9 +++++++-- pandas/tests/test_frame.py | 10 ++++++++++ pandas/util/testing.py | 6 +++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 3ac3c8eef0a10..afe7f8775b1e9 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -340,7 +340,7 @@ def drop(self, labels, axis=0, level=None): dropped : type of caller """ axis_name = self._get_axis_name(axis) - axis = self._get_axis(axis) + axis, axis_ = self._get_axis(axis), axis if axis.is_unique: if level is not None: @@ -349,8 +349,13 @@ def drop(self, labels, axis=0, level=None): new_axis = axis.drop(labels, level=level) else: new_axis = axis.drop(labels) + dropped = self.reindex(**{axis_name: new_axis}) + try: + dropped.axes[axis_].names = axis.names + except AttributeError: + pass + return dropped - return self.reindex(**{axis_name: new_axis}) else: if level is not None: if not isinstance(axis, MultiIndex): diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index 9329bb1da2b07..69c8bdc60be13 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -5103,6 +5103,16 @@ def test_corrwith_series(self): assert_series_equal(result, expected) + def test_drop_names(self): + df = DataFrame([[1, 2, 3],[3, 4, 5],[5, 6, 7]], index=['a', 'b', 'c'], columns=['d', 'e', 'f']) + df.index.name, df.columns.name = 'first', 'second' + df_dropped_b = df.drop('b') + df_dropped_e = df.drop('e', axis=1) + self.assert_(df_dropped_b.index.name == 'first') + self.assert_(df_dropped_e.index.name == 'first') + self.assert_(df_dropped_b.columns.name == 'second') + self.assert_(df_dropped_e.columns.name == 'second') + def test_dropEmptyRows(self): N = len(self.frame.index) mat = randn(N) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 97059d9aaf9f5..fddbf405b3f26 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -178,7 +178,8 @@ def assert_frame_equal(left, right, check_dtype=True, check_index_type=False, check_column_type=False, check_frame_type=False, - check_less_precise=False): + check_less_precise=False, + check_names=False): if check_frame_type: assert(type(left) == type(right)) assert(isinstance(left, DataFrame)) @@ -204,6 +205,9 @@ def assert_frame_equal(left, right, check_dtype=True, assert(type(left.columns) == type(right.columns)) assert(left.columns.dtype == right.columns.dtype) assert(left.columns.inferred_type == right.columns.inferred_type) + if check_names: + assert(left.index.names == right.index.names) + assert(left.columns.names == right.columns.names) def assert_panel_equal(left, right,