Skip to content

Commit 86a1719

Browse files
committed
BUG/API: .merge() and .join() on category dtype columns will now preserve the category dtype when possible
closes #10409
1 parent ec9bd44 commit 86a1719

File tree

8 files changed

+263
-63
lines changed

8 files changed

+263
-63
lines changed

asv_bench/benchmarks/join_merge.py

+24
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,30 @@ def time_i8merge(self):
257257
merge(self.left, self.right, how='outer')
258258

259259

260+
class MergeCategoricals(object):
261+
goal_time = 0.2
262+
263+
def setup(self):
264+
self.left_object = pd.DataFrame(
265+
{'X': np.random.choice(range(0, 10), size=(10000,)),
266+
'Y': np.random.choice(['one', 'two', 'three'], size=(10000,))})
267+
268+
self.right_object = pd.DataFrame(
269+
{'X': np.random.choice(range(0, 10), size=(10000,)),
270+
'Z': np.random.choice(['jjj', 'kkk', 'sss'], size=(10000,))})
271+
272+
self.left_cat = self.left_object.assign(
273+
Y=self.left_object['Y'].astype('category'))
274+
self.right_cat = self.right_object.assign(
275+
Z=self.right_object['Z'].astype('category'))
276+
277+
def time_merge_object(self):
278+
merge(self.left_object, self.right_object, on='X')
279+
280+
def time_merge_cat(self):
281+
merge(self.left_cat, self.right_cat, on='X')
282+
283+
260284
#----------------------------------------------------------------------
261285
# Ordered merge
262286

doc/source/whatsnew/v0.20.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ Other API Changes
371371
- ``pandas.api.types.is_datetime64_ns_dtype`` will now report ``True`` on a tz-aware dtype, similar to ``pandas.api.types.is_datetime64_any_dtype``
372372
- ``DataFrame.asof()`` will return a null filled ``Series`` instead the scalar ``NaN`` if a match is not found (:issue:`15118`)
373373
- The :func:`pd.read_gbq` method now stores ``INTEGER`` columns as ``dtype=object`` if they contain ``NULL`` values. Otherwise they are stored as ``int64``. This prevents precision lost for integers greather than 2**53. Furthermore ``FLOAT`` columns with values above 10**4 are no more casted to ``int64`` which also caused precision lost (:issue: `14064`, :issue:`14305`).
374+
- ``.merge()`` and ``.join()`` on ``category`` dtype columns will now preserve the category dtype when possible (:issue:`10409`)
374375

375376
.. _whatsnew_0200.deprecations:
376377

@@ -412,6 +413,7 @@ Performance Improvements
412413
- Improved performance of timeseries plotting with an irregular DatetimeIndex
413414
(or with ``compat_x=True``) (:issue:`15073`).
414415
- Improved performance of ``groupby().cummin()`` and ``groupby().cummax()`` (:issue:`15048`, :issue:`15109`)
416+
- Improved performance of merge/join on ``category`` columns (:issue:`10409`)
415417

416418
- When reading buffer object in ``read_sas()`` method without specified format, filepath string is inferred rather than buffer object.
417419

pandas/core/internals.py

+2
Original file line numberDiff line numberDiff line change
@@ -5224,6 +5224,8 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
52245224
# External code requested filling/upcasting, bool values must
52255225
# be upcasted to object to avoid being upcasted to numeric.
52265226
values = self.block.astype(np.object_).values
5227+
elif self.block.is_categorical:
5228+
values = self.block.values
52275229
else:
52285230
# No dtype upcasting is done here, it will be performed during
52295231
# concatenation itself.

pandas/tests/test_categorical.py

+3
Original file line numberDiff line numberDiff line change
@@ -4097,9 +4097,12 @@ def test_merge(self):
40974097
expected = df.copy()
40984098

40994099
# object-cat
4100+
# note that we propogate the category
4101+
# because we don't have any matching rows
41004102
cright = right.copy()
41014103
cright['d'] = cright['d'].astype('category')
41024104
result = pd.merge(left, cright, how='left', left_on='b', right_on='c')
4105+
expected['d'] = expected['d'].astype('category', categories=['null'])
41034106
tm.assert_frame_equal(result, expected)
41044107

41054108
# cat-object

pandas/tests/types/test_common.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,33 @@ def test_period_dtype(self):
3939

4040

4141
def test_dtype_equal():
42-
assert is_dtype_equal(np.int64, np.int64)
43-
assert not is_dtype_equal(np.int64, np.float64)
4442

45-
p1 = PeriodDtype('D')
46-
p2 = PeriodDtype('D')
47-
assert is_dtype_equal(p1, p2)
48-
assert not is_dtype_equal(np.int64, p1)
43+
dtypes = dict(dt_tz=pandas_dtype('datetime64[ns, US/Eastern]'),
44+
dt=pandas_dtype('datetime64[ns]'),
45+
td=pandas_dtype('timedelta64[ns]'),
46+
p=PeriodDtype('D'),
47+
i=np.int64,
48+
f=np.float64,
49+
o=np.object)
4950

50-
p3 = PeriodDtype('2D')
51-
assert not is_dtype_equal(p1, p3)
51+
# match equal to self, but not equal to other
52+
for name, dtype in dtypes.items():
53+
assert is_dtype_equal(dtype, dtype)
5254

53-
assert not DatetimeTZDtype.is_dtype(np.int64)
54-
assert not PeriodDtype.is_dtype(np.int64)
55+
for name2, dtype2 in dtypes.items():
56+
if name != name2:
57+
assert not is_dtype_equal(dtype, dtype2)
58+
59+
# we are strict on kind equality
60+
for dtype in [np.int8, np.int16, np.int32]:
61+
assert not is_dtype_equal(dtypes['i'], dtype)
62+
63+
for dtype in [np.float32]:
64+
assert not is_dtype_equal(dtypes['f'], dtype)
65+
66+
# strict w.r.t. PeriodDtype
67+
assert not is_dtype_equal(dtypes['p'], PeriodDtype('2D'))
68+
69+
# strict w.r.t. datetime64
70+
assert not is_dtype_equal(dtypes['dt_tz'],
71+
pandas_dtype('datetime64[ns, CET]'))

pandas/tools/merge.py

+65-21
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212

1313
import pandas as pd
1414
from pandas import (Categorical, Series, DataFrame,
15-
Index, MultiIndex, Timedelta)
15+
Index, MultiIndex, Timedelta, lib)
1616
from pandas.core.frame import _merge_doc
1717
from pandas.types.common import (is_datetime64tz_dtype,
1818
is_datetime64_dtype,
1919
needs_i8_conversion,
2020
is_int64_dtype,
21+
is_categorical_dtype,
2122
is_integer_dtype,
2223
is_float_dtype,
24+
is_numeric_dtype,
2325
is_integer,
2426
is_int_or_datetime_dtype,
2527
is_dtype_equal,
@@ -567,6 +569,10 @@ def __init__(self, left, right, how='inner', on=None,
567569
self.right_join_keys,
568570
self.join_names) = self._get_merge_keys()
569571

572+
# validate the merge keys dtypes. We may need to coerce
573+
# to avoid incompat dtypes
574+
self._maybe_coerce_merge_keys()
575+
570576
def get_result(self):
571577
if self.indicator:
572578
self.left, self.right = self._indicator_pre_merge(
@@ -757,26 +763,6 @@ def _get_join_info(self):
757763
join_index = join_index.astype(object)
758764
return join_index, left_indexer, right_indexer
759765

760-
def _get_merge_data(self):
761-
"""
762-
Handles overlapping column names etc.
763-
"""
764-
ldata, rdata = self.left._data, self.right._data
765-
lsuf, rsuf = self.suffixes
766-
767-
llabels, rlabels = items_overlap_with_suffix(
768-
ldata.items, lsuf, rdata.items, rsuf)
769-
770-
if not llabels.equals(ldata.items):
771-
ldata = ldata.copy(deep=False)
772-
ldata.set_axis(0, llabels)
773-
774-
if not rlabels.equals(rdata.items):
775-
rdata = rdata.copy(deep=False)
776-
rdata.set_axis(0, rlabels)
777-
778-
return ldata, rdata
779-
780766
def _get_merge_keys(self):
781767
"""
782768
Note: has side effects (copy/delete key columns)
@@ -888,6 +874,51 @@ def _get_merge_keys(self):
888874

889875
return left_keys, right_keys, join_names
890876

877+
def _maybe_coerce_merge_keys(self):
878+
# we have valid mergee's but we may have to further
879+
# coerce these if they are originally incompatible types
880+
#
881+
# for example if these are categorical, but are not dtype_equal
882+
# or if we have object and integer dtypes
883+
884+
for lk, rk, name in zip(self.left_join_keys,
885+
self.right_join_keys,
886+
self.join_names):
887+
if (len(lk) and not len(rk)) or (not len(lk) and len(rk)):
888+
continue
889+
890+
# if either left or right is a categorical
891+
# then the must match exactly in categories & ordered
892+
if is_categorical_dtype(lk) and is_categorical_dtype(rk):
893+
if lk.is_dtype_equal(rk):
894+
continue
895+
elif is_categorical_dtype(lk) or is_categorical_dtype(rk):
896+
pass
897+
898+
elif is_dtype_equal(lk.dtype, rk.dtype):
899+
continue
900+
901+
# if we are numeric, then allow differing
902+
# kinds to proceed, eg. int64 and int8
903+
# further if we are object, but we infer to
904+
# the same, then proceed
905+
if (is_numeric_dtype(lk) and is_numeric_dtype(rk)):
906+
if lk.dtype.kind == rk.dtype.kind:
907+
continue
908+
909+
# let's infer and see if we are ok
910+
if lib.infer_dtype(lk) == lib.infer_dtype(rk):
911+
continue
912+
913+
# Houston, we have a problem!
914+
# let's coerce to object
915+
if name in self.left.columns:
916+
self.left = self.left.assign(
917+
**{name: self.left[name].astype(object)})
918+
if name in self.right.columns:
919+
self.right = self.right.assign(
920+
**{name: self.right[name].astype(object)})
921+
891922
def _validate_specification(self):
892923
# Hm, any way to make this logic less complicated??
893924
if self.on is None and self.left_on is None and self.right_on is None:
@@ -939,9 +970,15 @@ def _get_join_indexers(left_keys, right_keys, sort=False, how='inner',
939970
940971
Parameters
941972
----------
973+
left_keys: ndarray, Index, Series
974+
right_keys: ndarray, Index, Series
975+
sort: boolean, default False
976+
how: string {'inner', 'outer', 'left', 'right'}, default 'inner'
942977
943978
Returns
944979
-------
980+
tuple of (left_indexer, right_indexer)
981+
indexers into the left_keys, right_keys
945982
946983
"""
947984
from functools import partial
@@ -1345,6 +1382,13 @@ def _factorize_keys(lk, rk, sort=True):
13451382
if is_datetime64tz_dtype(lk) and is_datetime64tz_dtype(rk):
13461383
lk = lk.values
13471384
rk = rk.values
1385+
1386+
# if we exactly match in categories, allow us to use codes
1387+
if (is_categorical_dtype(lk) and
1388+
is_categorical_dtype(rk) and
1389+
lk.is_dtype_equal(rk)):
1390+
return lk.codes, rk.codes, len(lk.categories)
1391+
13481392
if is_int_or_datetime_dtype(lk) and is_int_or_datetime_dtype(rk):
13491393
klass = _hash.Int64Factorizer
13501394
lk = _ensure_int64(com._values_from_object(lk))

0 commit comments

Comments
 (0)