diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 3ac3c8eef0a10..afe7f8775b1e9 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -340,7 +340,7 @@ def drop(self, labels, axis=0, level=None): dropped : type of caller """ axis_name = self._get_axis_name(axis) - axis = self._get_axis(axis) + axis, axis_ = self._get_axis(axis), axis if axis.is_unique: if level is not None: @@ -349,8 +349,13 @@ def drop(self, labels, axis=0, level=None): new_axis = axis.drop(labels, level=level) else: new_axis = axis.drop(labels) + dropped = self.reindex(**{axis_name: new_axis}) + try: + dropped.axes[axis_].names = axis.names + except AttributeError: + pass + return dropped - return self.reindex(**{axis_name: new_axis}) else: if level is not None: if not isinstance(axis, MultiIndex): diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index 9329bb1da2b07..69c8bdc60be13 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -5103,6 +5103,16 @@ def test_corrwith_series(self): assert_series_equal(result, expected) + def test_drop_names(self): + df = DataFrame([[1, 2, 3],[3, 4, 5],[5, 6, 7]], index=['a', 'b', 'c'], columns=['d', 'e', 'f']) + df.index.name, df.columns.name = 'first', 'second' + df_dropped_b = df.drop('b') + df_dropped_e = df.drop('e', axis=1) + self.assert_(df_dropped_b.index.name == 'first') + self.assert_(df_dropped_e.index.name == 'first') + self.assert_(df_dropped_b.columns.name == 'second') + self.assert_(df_dropped_e.columns.name == 'second') + def test_dropEmptyRows(self): N = len(self.frame.index) mat = randn(N) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 97059d9aaf9f5..6ce6a6d02c427 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -200,10 +200,12 @@ def assert_frame_equal(left, right, check_dtype=True, assert(type(left.index) == type(right.index)) assert(left.index.dtype == right.index.dtype) assert(left.index.inferred_type == right.index.inferred_type) + assert(left.index.names == right.index.names) if check_column_type: assert(type(left.columns) == type(right.columns)) assert(left.columns.dtype == right.columns.dtype) assert(left.columns.inferred_type == right.columns.inferred_type) + assert(left.columns.names == right.columns.names) def assert_panel_equal(left, right,