Skip to content

Commit 61b14b2

Browse files
sinhrksjreback
authored andcommitted
COMPAT: Categorical Subclassing
xref pandas-dev#8640 Author: sinhrks <[email protected]> Closes pandas-dev#13827 from sinhrks/categorical_subclass and squashes the following commits: 13c456c [sinhrks] COMPAT: Categorical Subclassing
1 parent 8ec7406 commit 61b14b2

File tree

3 files changed

+71
-30
lines changed

3 files changed

+71
-30
lines changed

pandas/core/categorical.py

+33-29
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,16 @@ def __init__(self, values, categories=None, ordered=False,
328328
self._categories = categories
329329
self._codes = _coerce_indexer_dtype(codes, categories)
330330

331+
@property
332+
def _constructor(self):
333+
return Categorical
334+
331335
def copy(self):
332336
""" Copy constructor. """
333-
return Categorical(values=self._codes.copy(),
334-
categories=self.categories, ordered=self.ordered,
335-
fastpath=True)
337+
return self._constructor(values=self._codes.copy(),
338+
categories=self.categories,
339+
ordered=self.ordered,
340+
fastpath=True)
336341

337342
def astype(self, dtype, copy=True):
338343
"""
@@ -414,7 +419,7 @@ def from_array(cls, data, **kwargs):
414419
Can be an Index or array-like. The categories are assumed to be
415420
the unique values of `data`.
416421
"""
417-
return Categorical(data, **kwargs)
422+
return cls(data, **kwargs)
418423

419424
@classmethod
420425
def from_codes(cls, codes, categories, ordered=False, name=None):
@@ -458,8 +463,8 @@ def from_codes(cls, codes, categories, ordered=False, name=None):
458463
raise ValueError("codes need to be between -1 and "
459464
"len(categories)-1")
460465

461-
return Categorical(codes, categories=categories, ordered=ordered,
462-
fastpath=True)
466+
return cls(codes, categories=categories, ordered=ordered,
467+
fastpath=True)
463468

464469
_codes = None
465470

@@ -916,9 +921,9 @@ def map(self, mapper):
916921
"""
917922
new_categories = self.categories.map(mapper)
918923
try:
919-
return Categorical.from_codes(self._codes.copy(),
920-
categories=new_categories,
921-
ordered=self.ordered)
924+
return self.from_codes(self._codes.copy(),
925+
categories=new_categories,
926+
ordered=self.ordered)
922927
except ValueError:
923928
return np.take(new_categories, self._codes)
924929

@@ -968,8 +973,8 @@ def shift(self, periods):
968973
else:
969974
codes[periods:] = -1
970975

971-
return Categorical.from_codes(codes, categories=self.categories,
972-
ordered=self.ordered)
976+
return self.from_codes(codes, categories=self.categories,
977+
ordered=self.ordered)
973978

974979
def __array__(self, dtype=None):
975980
"""
@@ -1159,8 +1164,8 @@ def value_counts(self, dropna=True):
11591164
count = bincount(np.where(mask, code, ncat))
11601165
ix = np.append(ix, -1)
11611166

1162-
ix = Categorical(ix, categories=cat, ordered=obj.ordered,
1163-
fastpath=True)
1167+
ix = self._constructor(ix, categories=cat, ordered=obj.ordered,
1168+
fastpath=True)
11641169

11651170
return Series(count, index=CategoricalIndex(ix), dtype='int64')
11661171

@@ -1313,8 +1318,8 @@ def sort_values(self, inplace=False, ascending=True, na_position='last'):
13131318
self._codes = codes
13141319
return
13151320
else:
1316-
return Categorical(values=codes, categories=self.categories,
1317-
ordered=self.ordered, fastpath=True)
1321+
return self._constructor(values=codes, categories=self.categories,
1322+
ordered=self.ordered, fastpath=True)
13181323

13191324
def order(self, inplace=False, ascending=True, na_position='last'):
13201325
"""
@@ -1441,8 +1446,8 @@ def fillna(self, value=None, method=None, limit=None):
14411446
values = values.copy()
14421447
values[mask] = self.categories.get_loc(value)
14431448

1444-
return Categorical(values, categories=self.categories,
1445-
ordered=self.ordered, fastpath=True)
1449+
return self._constructor(values, categories=self.categories,
1450+
ordered=self.ordered, fastpath=True)
14461451

14471452
def take_nd(self, indexer, allow_fill=True, fill_value=None):
14481453
""" Take the codes by the indexer, fill with the fill_value.
@@ -1455,8 +1460,8 @@ def take_nd(self, indexer, allow_fill=True, fill_value=None):
14551460
assert isnull(fill_value)
14561461

14571462
codes = take_1d(self._codes, indexer, allow_fill=True, fill_value=-1)
1458-
result = Categorical(codes, categories=self.categories,
1459-
ordered=self.ordered, fastpath=True)
1463+
result = self._constructor(codes, categories=self.categories,
1464+
ordered=self.ordered, fastpath=True)
14601465
return result
14611466

14621467
take = take_nd
@@ -1476,8 +1481,8 @@ def _slice(self, slicer):
14761481
slicer = slicer[1]
14771482

14781483
_codes = self._codes[slicer]
1479-
return Categorical(values=_codes, categories=self.categories,
1480-
ordered=self.ordered, fastpath=True)
1484+
return self._constructor(values=_codes, categories=self.categories,
1485+
ordered=self.ordered, fastpath=True)
14811486

14821487
def __len__(self):
14831488
"""The length of this Categorical."""
@@ -1588,10 +1593,9 @@ def __getitem__(self, key):
15881593
else:
15891594
return self.categories[i]
15901595
else:
1591-
return Categorical(values=self._codes[key],
1592-
categories=self.categories,
1593-
ordered=self.ordered,
1594-
fastpath=True)
1596+
return self._constructor(values=self._codes[key],
1597+
categories=self.categories,
1598+
ordered=self.ordered, fastpath=True)
15951599

15961600
def __setitem__(self, key, value):
15971601
""" Item assignment.
@@ -1742,8 +1746,8 @@ def mode(self):
17421746
import pandas.hashtable as htable
17431747
good = self._codes != -1
17441748
values = sorted(htable.mode_int64(_ensure_int64(self._codes[good])))
1745-
result = Categorical(values=values, categories=self.categories,
1746-
ordered=self.ordered, fastpath=True)
1749+
result = self._constructor(values=values, categories=self.categories,
1750+
ordered=self.ordered, fastpath=True)
17471751
return result
17481752

17491753
def unique(self):
@@ -1837,8 +1841,8 @@ def repeat(self, repeats, *args, **kwargs):
18371841
"""
18381842
nv.validate_repeat(args, kwargs)
18391843
codes = self._codes.repeat(repeats)
1840-
return Categorical(values=codes, categories=self.categories,
1841-
ordered=self.ordered, fastpath=True)
1844+
return self._constructor(values=codes, categories=self.categories,
1845+
ordered=self.ordered, fastpath=True)
18421846

18431847
# The Series.cat accessor
18441848

pandas/tests/test_categorical.py

+30
Original file line numberDiff line numberDiff line change
@@ -4415,6 +4415,36 @@ def test_concat_categorical(self):
44154415
tm.assert_frame_equal(df_expected, df_concat)
44164416

44174417

4418+
class TestCategoricalSubclassing(tm.TestCase):
4419+
4420+
_multiprocess_can_split_ = True
4421+
4422+
def test_constructor(self):
4423+
sc = tm.SubclassedCategorical(['a', 'b', 'c'])
4424+
self.assertIsInstance(sc, tm.SubclassedCategorical)
4425+
tm.assert_categorical_equal(sc, Categorical(['a', 'b', 'c']))
4426+
4427+
def test_from_array(self):
4428+
sc = tm.SubclassedCategorical.from_codes([1, 0, 2], ['a', 'b', 'c'])
4429+
self.assertIsInstance(sc, tm.SubclassedCategorical)
4430+
exp = Categorical.from_codes([1, 0, 2], ['a', 'b', 'c'])
4431+
tm.assert_categorical_equal(sc, exp)
4432+
4433+
def test_map(self):
4434+
sc = tm.SubclassedCategorical(['a', 'b', 'c'])
4435+
res = sc.map(lambda x: x.upper())
4436+
self.assertIsInstance(res, tm.SubclassedCategorical)
4437+
exp = Categorical(['A', 'B', 'C'])
4438+
tm.assert_categorical_equal(res, exp)
4439+
4440+
def test_map(self):
4441+
sc = tm.SubclassedCategorical(['a', 'b', 'c'])
4442+
res = sc.map(lambda x: x.upper())
4443+
self.assertIsInstance(res, tm.SubclassedCategorical)
4444+
exp = Categorical(['A', 'B', 'C'])
4445+
tm.assert_categorical_equal(res, exp)
4446+
4447+
44184448
if __name__ == '__main__':
44194449
import nose
44204450
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],

pandas/util/testing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
from pandas.computation import expressions as expr
4545

46-
from pandas import (bdate_range, CategoricalIndex, DatetimeIndex,
46+
from pandas import (bdate_range, CategoricalIndex, Categorical, DatetimeIndex,
4747
TimedeltaIndex, PeriodIndex, RangeIndex, Index, MultiIndex,
4848
Series, DataFrame, Panel, Panel4D)
4949
from pandas.util.decorators import deprecate
@@ -2670,6 +2670,13 @@ def _constructor_sliced(self):
26702670
return SubclassedSparseSeries
26712671

26722672

2673+
class SubclassedCategorical(Categorical):
2674+
2675+
@property
2676+
def _constructor(self):
2677+
return SubclassedCategorical
2678+
2679+
26732680
@contextmanager
26742681
def patch(ob, attr, value):
26752682
"""Temporarily patch an attribute of an object.

0 commit comments

Comments
 (0)