Skip to content

Commit c559662

Browse files
sinhrkschris-b1
authored andcommitted
ENH: add sort_categories argument to union_categoricals
1 parent 59f2557 commit c559662

File tree

3 files changed

+80
-29
lines changed

3 files changed

+80
-29
lines changed

doc/source/categorical.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ Unioning
656656
.. versionadded:: 0.19.0
657657

658658
If you want to combine categoricals that do not necessarily have
659-
the same categories, the `union_categorical` function will
659+
the same categories, the ``union_categoricals`` function will
660660
combine a list-like of categoricals. The new categories
661661
will be the union of the categories being combined.
662662

@@ -667,6 +667,14 @@ will be the union of the categories being combined.
667667
b = pd.Categorical(["a", "b"])
668668
union_categoricals([a, b])
669669
670+
By default, the resulting categories will be ordered as
671+
they appear in the data. If you want the categories to
672+
be lexsorted, use ``sort_categories=True`` argument.
673+
674+
.. ipython:: python
675+
676+
union_categoricals([a, b], sort_categories=True)
677+
670678
.. note::
671679

672680
In addition to the "easy" case of combining two categoricals of the same

pandas/tools/tests/test_concat.py

+36
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,42 @@ def test_union_categoricals_ordered(self):
989989
with tm.assertRaisesRegexp(TypeError, msg):
990990
union_categoricals([c1, c2])
991991

992+
def test_union_categoricals_sort(self):
993+
# GH 13763
994+
c1 = Categorical(['x', 'y', 'z'])
995+
c2 = Categorical(['a', 'b', 'c'])
996+
result = union_categoricals([c1, c2], sort_categories=True)
997+
expected = Categorical(['x', 'y', 'z', 'a', 'b', 'c'],
998+
categories=['a', 'b', 'c', 'x', 'y', 'z'])
999+
tm.assert_categorical_equal(result, expected)
1000+
1001+
# fastpath
1002+
c1 = Categorical(['a', 'b'], categories=['b', 'a', 'c'])
1003+
c2 = Categorical(['b', 'c'], categories=['b', 'a', 'c'])
1004+
result = union_categoricals([c1, c2], sort_categories=True)
1005+
expected = Categorical(['a', 'b', 'b', 'c'],
1006+
categories=['a', 'b', 'c'])
1007+
tm.assert_categorical_equal(result, expected)
1008+
1009+
c1 = Categorical(['x', np.nan])
1010+
c2 = Categorical([np.nan, 'b'])
1011+
result = union_categoricals([c1, c2], sort_categories=True)
1012+
expected = Categorical(['x', np.nan, np.nan, 'b'],
1013+
categories=['b', 'x'])
1014+
tm.assert_categorical_equal(result, expected)
1015+
1016+
c1 = Categorical([np.nan])
1017+
c2 = Categorical([np.nan])
1018+
result = union_categoricals([c1, c2], sort_categories=True)
1019+
expected = Categorical([np.nan, np.nan], categories=[])
1020+
tm.assert_categorical_equal(result, expected)
1021+
1022+
c1 = Categorical([])
1023+
c2 = Categorical([])
1024+
result = union_categoricals([c1, c2], sort_categories=True)
1025+
expected = Categorical([])
1026+
tm.assert_categorical_equal(result, expected)
1027+
9921028
def test_concat_bug_1719(self):
9931029
ts1 = tm.makeTimeSeries()
9941030
ts2 = tm.makeTimeSeries()[::2]

pandas/types/concat.py

+35-28
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,23 @@ def convert_categorical(x):
211211
return Categorical(concatted, rawcats)
212212

213213

214-
def union_categoricals(to_union):
214+
def union_categoricals(to_union, sort_categories=False):
215215
"""
216216
Combine list-like of Categoricals, unioning categories. All
217-
must have the same dtype, and none can be ordered.
217+
categories must have the same dtype.
218218
219219
.. versionadded:: 0.19.0
220220
221221
Parameters
222222
----------
223223
to_union : list-like of Categoricals
224+
sort_categories : boolean, default False
225+
If true, resulting categories will be lexsorted, otherwise
226+
they will be ordered as they appear in the data
224227
225228
Returns
226229
-------
227-
Categorical
228-
A single array, categories will be ordered as they
229-
appear in the list
230+
result : Categorical
230231
231232
Raises
232233
------
@@ -244,41 +245,47 @@ def union_categoricals(to_union):
244245

245246
first = to_union[0]
246247

247-
if not all(is_dtype_equal(c.categories.dtype, first.categories.dtype)
248-
for c in to_union):
248+
if not all(is_dtype_equal(other.categories.dtype, first.categories.dtype)
249+
for other in to_union[1:]):
249250
raise TypeError("dtype of categories must be the same")
250251

252+
ordered = False
251253
if all(first.is_dtype_equal(other) for other in to_union[1:]):
252-
return Categorical(np.concatenate([c.codes for c in to_union]),
253-
categories=first.categories, ordered=first.ordered,
254-
fastpath=True)
254+
# identical categories - fastpath
255+
categories = first.categories
256+
ordered = first.ordered
257+
new_codes = np.concatenate([c.codes for c in to_union])
258+
259+
if sort_categories:
260+
categories = categories.sort_values()
261+
indexer = first.categories.get_indexer(categories)
262+
new_codes = take_1d(indexer, new_codes, fill_value=-1)
255263
elif all(not c.ordered for c in to_union):
256-
# not ordered
257-
pass
264+
# different categories - union and recode
265+
cats = first.categories.append([c.categories for c in to_union[1:]])
266+
categories = Index(cats.unique())
267+
if sort_categories:
268+
categories = categories.sort_values()
269+
270+
new_codes = []
271+
for c in to_union:
272+
if len(c.categories) > 0:
273+
indexer = categories.get_indexer(c.categories)
274+
new_codes.append(take_1d(indexer, c.codes, fill_value=-1))
275+
else:
276+
# must be all NaN
277+
new_codes.append(c.codes)
278+
new_codes = np.concatenate(new_codes)
258279
else:
259-
# to show a proper error message
280+
# ordered - to show a proper error message
260281
if all(c.ordered for c in to_union):
261282
msg = ("to union ordered Categoricals, "
262283
"all categories must be the same")
263284
raise TypeError(msg)
264285
else:
265286
raise TypeError('Categorical.ordered must be the same')
266287

267-
cats = first.categories
268-
unique_cats = cats.append([c.categories for c in to_union[1:]]).unique()
269-
categories = Index(unique_cats)
270-
271-
new_codes = []
272-
for c in to_union:
273-
if len(c.categories) > 0:
274-
indexer = categories.get_indexer(c.categories)
275-
new_codes.append(take_1d(indexer, c.codes, fill_value=-1))
276-
else:
277-
# must be all NaN
278-
new_codes.append(c.codes)
279-
280-
new_codes = np.concatenate(new_codes)
281-
return Categorical(new_codes, categories=categories, ordered=False,
288+
return Categorical(new_codes, categories=categories, ordered=ordered,
282289
fastpath=True)
283290

284291

0 commit comments

Comments
 (0)