Skip to content

Commit 63dd41f

Browse files
committed
BUG/API: .merge() and .join() on category dtype columns will now preserve the category dtype when possible
closes pandas-dev#10409
1 parent a0f7fc0 commit 63dd41f

File tree

8 files changed

+263
-63
lines changed

8 files changed

+263
-63
lines changed

asv_bench/benchmarks/join_merge.py

+24
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,30 @@ def time_i8merge(self):
257257
merge(self.left, self.right, how='outer')
258258

259259

260+
class MergeCategoricals(object):
261+
goal_time = 0.2
262+
263+
def setup(self):
264+
self.left_object = pd.DataFrame(
265+
{'X': np.random.choice(range(0, 10), size=(10000,)),
266+
'Y': np.random.choice(['one', 'two', 'three'], size=(10000,))})
267+
268+
self.right_object = pd.DataFrame(
269+
{'X': np.random.choice(range(0, 10), size=(10000,)),
270+
'Z': np.random.choice(['jjj', 'kkk', 'sss'], size=(10000,))})
271+
272+
self.left_cat = self.left_object.assign(
273+
Y=self.left_object['Y'].astype('category'))
274+
self.right_cat = self.right_object.assign(
275+
Z=self.right_object['Z'].astype('category'))
276+
277+
def time_merge_object(self):
278+
merge(self.left_object, self.right_object, on='X')
279+
280+
def time_merge_cat(self):
281+
merge(self.left_cat, self.right_cat, on='X')
282+
283+
260284
#----------------------------------------------------------------------
261285
# Ordered merge
262286

doc/source/whatsnew/v0.20.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ Other API Changes
428428
- ``DataFrame.asof()`` will return a null filled ``Series`` instead the scalar ``NaN`` if a match is not found (:issue:`15118`)
429429
- The :func:`pd.read_gbq` method now stores ``INTEGER`` columns as ``dtype=object`` if they contain ``NULL`` values. Otherwise they are stored as ``int64``. This prevents precision lost for integers greather than 2**53. Furthermore ``FLOAT`` columns with values above 10**4 are no more casted to ``int64`` which also caused precision lost (:issue: `14064`, :issue:`14305`).
430430
- Reorganization of timeseries development tests (:issue:`14854`)
431+
- ``.merge()`` and ``.join()`` on ``category`` dtype columns will now preserve the category dtype when possible (:issue:`10409`)
431432

432433
.. _whatsnew_0200.deprecations:
433434

@@ -469,6 +470,7 @@ Performance Improvements
469470
- Improved performance of timeseries plotting with an irregular DatetimeIndex
470471
(or with ``compat_x=True``) (:issue:`15073`).
471472
- Improved performance of ``groupby().cummin()`` and ``groupby().cummax()`` (:issue:`15048`, :issue:`15109`)
473+
- Improved performance of merge/join on ``category`` columns (:issue:`10409`)
472474

473475
- When reading buffer object in ``read_sas()`` method without specified format, filepath string is inferred rather than buffer object.
474476

pandas/core/internals.py

+2
Original file line numberDiff line numberDiff line change
@@ -5224,6 +5224,8 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
52245224
# External code requested filling/upcasting, bool values must
52255225
# be upcasted to object to avoid being upcasted to numeric.
52265226
values = self.block.astype(np.object_).values
5227+
elif self.block.is_categorical:
5228+
values = self.block.values
52275229
else:
52285230
# No dtype upcasting is done here, it will be performed during
52295231
# concatenation itself.

pandas/tests/test_categorical.py

+3
Original file line numberDiff line numberDiff line change
@@ -4097,9 +4097,12 @@ def test_merge(self):
40974097
expected = df.copy()
40984098

40994099
# object-cat
4100+
# note that we propogate the category
4101+
# because we don't have any matching rows
41004102
cright = right.copy()
41014103
cright['d'] = cright['d'].astype('category')
41024104
result = pd.merge(left, cright, how='left', left_on='b', right_on='c')
4105+
expected['d'] = expected['d'].astype('category', categories=['null'])
41034106
tm.assert_frame_equal(result, expected)
41044107

41054108
# cat-object

pandas/tests/tools/test_merge.py

+139-32
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pandas.util.testing import (assert_frame_equal,
1414
assert_series_equal,
1515
slow)
16+
from pandas.types.dtypes import CategoricalDtype
17+
from pandas.types.common import is_categorical_dtype, is_object_dtype
1618
from pandas import DataFrame, Index, MultiIndex, Series, Categorical
1719
import pandas.util.testing as tm
1820

@@ -1018,38 +1020,6 @@ def test_left_join_index_multi_match(self):
10181020
expected.index = np.arange(len(expected))
10191021
tm.assert_frame_equal(result, expected)
10201022

1021-
def test_join_multi_dtypes(self):
1022-
1023-
# test with multi dtypes in the join index
1024-
def _test(dtype1, dtype2):
1025-
left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=dtype1),
1026-
'k2': ['foo', 'bar'] * 12,
1027-
'v': np.array(np.arange(24), dtype=np.int64)})
1028-
1029-
index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
1030-
right = DataFrame(
1031-
{'v2': np.array([5, 7], dtype=dtype2)}, index=index)
1032-
1033-
result = left.join(right, on=['k1', 'k2'])
1034-
1035-
expected = left.copy()
1036-
1037-
if dtype2.kind == 'i':
1038-
dtype2 = np.dtype('float64')
1039-
expected['v2'] = np.array(np.nan, dtype=dtype2)
1040-
expected.loc[(expected.k1 == 2) & (expected.k2 == 'bar'), 'v2'] = 5
1041-
expected.loc[(expected.k1 == 1) & (expected.k2 == 'foo'), 'v2'] = 7
1042-
1043-
tm.assert_frame_equal(result, expected)
1044-
1045-
result = left.join(right, on=['k1', 'k2'], sort=True)
1046-
expected.sort_values(['k1', 'k2'], kind='mergesort', inplace=True)
1047-
tm.assert_frame_equal(result, expected)
1048-
1049-
for d1 in [np.int64, np.int32, np.int16, np.int8, np.uint8]:
1050-
for d2 in [np.int64, np.float64, np.float32, np.float16]:
1051-
_test(np.dtype(d1), np.dtype(d2))
1052-
10531023
def test_left_merge_na_buglet(self):
10541024
left = DataFrame({'id': list('abcde'), 'v1': randn(5),
10551025
'v2': randn(5), 'dummy': list('abcde'),
@@ -1367,3 +1337,140 @@ def f():
13671337
def f():
13681338
household.join(log_return, how='outer')
13691339
self.assertRaises(NotImplementedError, f)
1340+
1341+
1342+
class TestMergeDtypes(tm.TestCase):
1343+
1344+
def setUp(self):
1345+
1346+
self.df = DataFrame(
1347+
{'A': ['foo', 'bar'],
1348+
'B': Series(['foo', 'bar']).astype('category'),
1349+
'C': [1, 2],
1350+
'D': [1.0, 2.0],
1351+
'E': Series([1, 2], dtype='uint64'),
1352+
'F': Series([1, 2], dtype='int32')})
1353+
1354+
def test_different(self):
1355+
1356+
# we expect differences by kind
1357+
# to be ok, while other differences should return object
1358+
1359+
left = self.df
1360+
for col in self.df.columns:
1361+
right = DataFrame({'A': self.df[col]})
1362+
result = pd.merge(left, right, on='A')
1363+
self.assertTrue(is_object_dtype(result.A.dtype))
1364+
1365+
def test_join_multi_dtypes(self):
1366+
1367+
# test with multi dtypes in the join index
1368+
def _test(dtype1, dtype2):
1369+
left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=dtype1),
1370+
'k2': ['foo', 'bar'] * 12,
1371+
'v': np.array(np.arange(24), dtype=np.int64)})
1372+
1373+
index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
1374+
right = DataFrame(
1375+
{'v2': np.array([5, 7], dtype=dtype2)}, index=index)
1376+
1377+
result = left.join(right, on=['k1', 'k2'])
1378+
1379+
expected = left.copy()
1380+
1381+
if dtype2.kind == 'i':
1382+
dtype2 = np.dtype('float64')
1383+
expected['v2'] = np.array(np.nan, dtype=dtype2)
1384+
expected.loc[(expected.k1 == 2) & (expected.k2 == 'bar'), 'v2'] = 5
1385+
expected.loc[(expected.k1 == 1) & (expected.k2 == 'foo'), 'v2'] = 7
1386+
1387+
tm.assert_frame_equal(result, expected)
1388+
1389+
result = left.join(right, on=['k1', 'k2'], sort=True)
1390+
expected.sort_values(['k1', 'k2'], kind='mergesort', inplace=True)
1391+
tm.assert_frame_equal(result, expected)
1392+
1393+
for d1 in [np.int64, np.int32, np.int16, np.int8, np.uint8]:
1394+
for d2 in [np.int64, np.float64, np.float32, np.float16]:
1395+
_test(np.dtype(d1), np.dtype(d2))
1396+
1397+
1398+
class TestMergeCategorical(tm.TestCase):
1399+
_multiprocess_can_split_ = True
1400+
1401+
def setUp(self):
1402+
np.random.seed(1234)
1403+
self.left = DataFrame(
1404+
{'X': Series(np.random.choice(
1405+
['foo', 'bar'],
1406+
size=(10,))).astype('category', categories=['foo', 'bar']),
1407+
'Y': np.random.choice(['one', 'two', 'three'], size=(10,))})
1408+
self.right = pd.DataFrame(
1409+
{'X': Series(['foo', 'bar']).astype('category',
1410+
categories=['foo', 'bar']),
1411+
'Z': [1, 2]})
1412+
1413+
def test_identical(self):
1414+
# merging on the same, should preserve dtypes
1415+
merged = pd.merge(self.left, self.left, on='X')
1416+
result = merged.dtypes.sort_index()
1417+
expected = Series([CategoricalDtype(),
1418+
np.dtype('O'),
1419+
np.dtype('O')],
1420+
index=['X', 'Y_x', 'Y_y'])
1421+
assert_series_equal(result, expected)
1422+
1423+
def test_basic(self):
1424+
# we have matching Categorical dtypes in X
1425+
# so should preserve the merged column
1426+
merged = pd.merge(self.left, self.right, on='X')
1427+
result = merged.dtypes.sort_index()
1428+
expected = Series([CategoricalDtype(),
1429+
np.dtype('O'),
1430+
np.dtype('int64')],
1431+
index=['X', 'Y', 'Z'])
1432+
assert_series_equal(result, expected)
1433+
1434+
def test_other_columns(self):
1435+
# non-merge columns should preserve if possible
1436+
left = self.left
1437+
right = self.right.assign(Z=self.right.Z.astype('category'))
1438+
1439+
merged = pd.merge(left, right, on='X')
1440+
result = merged.dtypes.sort_index()
1441+
expected = Series([CategoricalDtype(),
1442+
np.dtype('O'),
1443+
CategoricalDtype()],
1444+
index=['X', 'Y', 'Z'])
1445+
assert_series_equal(result, expected)
1446+
1447+
# categories are preserved
1448+
self.assertTrue(left.X.values.is_dtype_equal(merged.X.values))
1449+
self.assertTrue(right.Z.values.is_dtype_equal(merged.Z.values))
1450+
1451+
def test_dtype_on_merged_different(self):
1452+
# our merging columns, X now has 2 different dtypes
1453+
# so we must be object as a result
1454+
left = self.left
1455+
1456+
for change in [lambda x: x,
1457+
lambda x: x.astype('category',
1458+
categories=['bar', 'foo']),
1459+
lambda x: x.astype('category',
1460+
categories=['foo', 'bar', 'bah']),
1461+
lambda x: x.astype('category', ordered=True)]:
1462+
for how in ['inner', 'outer', 'left', 'right']:
1463+
1464+
X = change(self.right.X.astype('object'))
1465+
right = self.right.assign(X=X)
1466+
self.assertTrue(is_categorical_dtype(left.X.values))
1467+
self.assertFalse(left.X.values.is_dtype_equal(right.X.values))
1468+
1469+
merged = pd.merge(left, right, on='X', how=how)
1470+
1471+
result = merged.dtypes.sort_index()
1472+
expected = Series([np.dtype('O'),
1473+
np.dtype('O'),
1474+
np.dtype('int64')],
1475+
index=['X', 'Y', 'Z'])
1476+
assert_series_equal(result, expected)

pandas/tests/tools/test_merge_asof.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_basic_categorical(self):
147147
trades.ticker = trades.ticker.astype('category')
148148
quotes = self.quotes.copy()
149149
quotes.ticker = quotes.ticker.astype('category')
150+
expected.ticker = expected.ticker.astype('category')
150151

151152
result = merge_asof(trades, quotes,
152153
on='time',

pandas/tests/types/test_common.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,33 @@ def test_period_dtype(self):
3939

4040

4141
def test_dtype_equal():
42-
assert is_dtype_equal(np.int64, np.int64)
43-
assert not is_dtype_equal(np.int64, np.float64)
4442

45-
p1 = PeriodDtype('D')
46-
p2 = PeriodDtype('D')
47-
assert is_dtype_equal(p1, p2)
48-
assert not is_dtype_equal(np.int64, p1)
43+
dtypes = dict(dt_tz=pandas_dtype('datetime64[ns, US/Eastern]'),
44+
dt=pandas_dtype('datetime64[ns]'),
45+
td=pandas_dtype('timedelta64[ns]'),
46+
p=PeriodDtype('D'),
47+
i=np.int64,
48+
f=np.float64,
49+
o=np.object)
4950

50-
p3 = PeriodDtype('2D')
51-
assert not is_dtype_equal(p1, p3)
51+
# match equal to self, but not equal to other
52+
for name, dtype in dtypes.items():
53+
assert is_dtype_equal(dtype, dtype)
5254

53-
assert not DatetimeTZDtype.is_dtype(np.int64)
54-
assert not PeriodDtype.is_dtype(np.int64)
55+
for name2, dtype2 in dtypes.items():
56+
if name != name2:
57+
assert not is_dtype_equal(dtype, dtype2)
58+
59+
# we are strict on kind equality
60+
for dtype in [np.int8, np.int16, np.int32]:
61+
assert not is_dtype_equal(dtypes['i'], dtype)
62+
63+
for dtype in [np.float32]:
64+
assert not is_dtype_equal(dtypes['f'], dtype)
65+
66+
# strict w.r.t. PeriodDtype
67+
assert not is_dtype_equal(dtypes['p'], PeriodDtype('2D'))
68+
69+
# strict w.r.t. datetime64
70+
assert not is_dtype_equal(dtypes['dt_tz'],
71+
pandas_dtype('datetime64[ns, CET]'))

0 commit comments

Comments
 (0)