Skip to content

Commit 7ca878e

Browse files
committed
API: add _to_safe_for_reshape to allow safe insert/append with embedded CategoricalIndexes
Signed-off-by: Jeff Reback <[email protected]>
1 parent d998337 commit 7ca878e

File tree

5 files changed

+46
-27
lines changed

5 files changed

+46
-27
lines changed

doc/source/whatsnew/v0.17.1.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Bug Fixes
8787

8888
- Bug in list-like indexing with a mixed-integer Index (:issue:`11320`)
8989

90-
90+
- Bug in ``pivot_table`` with ``margins=True`` when indexes are of ``Categorical`` dtype (:issue:`10993`)
9191
- Bug in ``DataFrame.plot`` cannot use hex strings colors (:issue:`10299`)
9292

9393

pandas/core/index.py

+12
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,10 @@ def astype(self, dtype):
627627
return Index(self.values.astype(dtype), name=self.name,
628628
dtype=dtype)
629629

630+
def _to_safe_for_reshape(self):
631+
""" convert to object if we are a categorical """
632+
return self
633+
630634
def to_datetime(self, dayfirst=False):
631635
"""
632636
For an Index containing strings or datetime.datetime objects, attempt
@@ -3190,6 +3194,10 @@ def duplicated(self, keep='first'):
31903194
from pandas.hashtable import duplicated_int64
31913195
return duplicated_int64(self.codes.astype('i8'), keep)
31923196

3197+
def _to_safe_for_reshape(self):
3198+
""" convert to object if we are a categorical """
3199+
return self.astype('object')
3200+
31933201
def get_loc(self, key, method=None):
31943202
"""
31953203
Get integer location for requested label
@@ -4529,6 +4537,10 @@ def format(self, space=2, sparsify=None, adjoin=True, names=False,
45294537
else:
45304538
return result_levels
45314539

4540+
def _to_safe_for_reshape(self):
4541+
""" convert to object if we are a categorical """
4542+
return self.set_levels([ i._to_safe_for_reshape() for i in self.levels ])
4543+
45324544
def to_hierarchical(self, n_repeat, n_shuffle=1):
45334545
"""
45344546
Return a MultiIndex reshaped to conform to the

pandas/core/internals.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -3427,6 +3427,9 @@ def insert(self, loc, item, value, allow_duplicates=False):
34273427
if not isinstance(loc, int):
34283428
raise TypeError("loc must be int")
34293429

3430+
# insert to the axis; this could possibly raise a TypeError
3431+
new_axis = self.items.insert(loc, item)
3432+
34303433
block = make_block(values=value,
34313434
ndim=self.ndim,
34323435
placement=slice(loc, loc+1))
@@ -3449,8 +3452,7 @@ def insert(self, loc, item, value, allow_duplicates=False):
34493452
self._blklocs = np.insert(self._blklocs, loc, 0)
34503453
self._blknos = np.insert(self._blknos, loc, len(self.blocks))
34513454

3452-
self.axes[0] = self.items.insert(loc, item)
3453-
3455+
self.axes[0] = new_axis
34543456
self.blocks += (block,)
34553457
self._shape = None
34563458

pandas/tools/pivot.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,6 @@ def _add_margins(table, data, values, rows, cols, aggfunc):
159159

160160
grand_margin = _compute_grand_margin(data, values, aggfunc)
161161

162-
# categorical index or columns will fail below when 'All' is added
163-
# here we'll convert all categorical indices to object
164-
def convert_categorical(ind):
165-
_convert = lambda ind: (ind.astype('object')
166-
if ind.dtype.name == 'category' else ind)
167-
if isinstance(ind, MultiIndex):
168-
return ind.set_levels([_convert(lev) for lev in ind.levels])
169-
else:
170-
return _convert(ind)
171-
172-
table.index = convert_categorical(table.index)
173-
if hasattr(table, 'columns'):
174-
table.columns = convert_categorical(table.columns)
175-
176162
if not values and isinstance(table, Series):
177163
# If there are no values and the table is a series, then there is only
178164
# one column in the data. Compute grand margin and return it.
@@ -203,7 +189,13 @@ def convert_categorical(ind):
203189
margin_dummy = DataFrame(row_margin, columns=[key]).T
204190

205191
row_names = result.index.names
206-
result = result.append(margin_dummy)
192+
try:
193+
result = result.append(margin_dummy)
194+
except TypeError:
195+
196+
# we cannot reshape, so coerce the axis
197+
result.index = result.index._to_safe_for_reshape()
198+
result = result.append(margin_dummy)
207199
result.index.names = row_names
208200

209201
return result
@@ -232,6 +224,7 @@ def _compute_grand_margin(data, values, aggfunc):
232224

233225

234226
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
227+
235228
if len(cols) > 0:
236229
# need to "interleave" the margins
237230
table_pieces = []
@@ -249,7 +242,13 @@ def _all_key(key):
249242

250243
# we are going to mutate this, so need to copy!
251244
piece = piece.copy()
252-
piece[all_key] = margin[key]
245+
try:
246+
piece[all_key] = margin[key]
247+
except TypeError:
248+
249+
# we cannot reshape, so coerce the axis
250+
piece.set_axis(cat_axis, piece._get_axis(cat_axis)._to_safe_for_reshape())
251+
piece[all_key] = margin[key]
253252

254253
table_pieces.append(piece)
255254
margin_keys.append(all_key)

pandas/tools/tests/test_pivot.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -721,17 +721,23 @@ def test_crosstab_dropna(self):
721721

722722
def test_categorical_margins(self):
723723
# GH 10989
724-
data = pd.DataFrame({'x': np.arange(8),
725-
'y': np.arange(8) // 4,
726-
'z': np.arange(8) % 2})
724+
df = pd.DataFrame({'x': np.arange(8),
725+
'y': np.arange(8) // 4,
726+
'z': np.arange(8) % 2})
727+
728+
expected = pd.DataFrame([[1.0, 2.0, 1.5],[5, 6, 5.5],[3, 4, 3.5]])
729+
expected.index = Index([0,1,'All'],name='y')
730+
expected.columns = Index([0,1,'All'],name='z')
731+
732+
data = df.copy()
733+
table = data.pivot_table('x', 'y', 'z', margins=True)
734+
tm.assert_frame_equal(table, expected)
735+
736+
data = df.copy()
727737
data.y = data.y.astype('category')
728738
data.z = data.z.astype('category')
729739
table = data.pivot_table('x', 'y', 'z', margins=True)
730-
assert_equal(table.values, [[1, 2, 1.5],
731-
[5, 6, 5.5],
732-
[3, 4, 3.5]])
733-
734-
740+
tm.assert_frame_equal(table, expected)
735741

736742
if __name__ == '__main__':
737743
import nose

0 commit comments

Comments
 (0)