Skip to content

Commit a4b2ee6

Browse files
committed
BUG/API: .merge() and .join() on category dtype columns will now preserve the category dtype when possible
closes pandas-dev#10409
1 parent 5dee1f1 commit a4b2ee6

File tree

8 files changed

+288
-71
lines changed

8 files changed

+288
-71
lines changed

asv_bench/benchmarks/join_merge.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pandas import ordered_merge as merge_ordered
77

88

9-
#----------------------------------------------------------------------
9+
# ----------------------------------------------------------------------
1010
# Append
1111

1212
class Append(object):
@@ -35,7 +35,7 @@ def time_append_mixed(self):
3535
self.mdf1.append(self.mdf2)
3636

3737

38-
#----------------------------------------------------------------------
38+
# ----------------------------------------------------------------------
3939
# Concat
4040

4141
class Concat(object):
@@ -120,7 +120,7 @@ def time_f_ordered_axis1(self):
120120
concat(self.frames_f, axis=1, ignore_index=True)
121121

122122

123-
#----------------------------------------------------------------------
123+
# ----------------------------------------------------------------------
124124
# Joins
125125

126126
class Join(object):
@@ -202,7 +202,7 @@ def time_join_non_unique_equal(self):
202202
(self.fracofday * self.temp[self.fracofday.index])
203203

204204

205-
#----------------------------------------------------------------------
205+
# ----------------------------------------------------------------------
206206
# Merges
207207

208208
class Merge(object):
@@ -257,7 +257,31 @@ def time_i8merge(self):
257257
merge(self.left, self.right, how='outer')
258258

259259

260-
#----------------------------------------------------------------------
260+
class MergeCategoricals(object):
261+
goal_time = 0.2
262+
263+
def setup(self):
264+
self.left_object = pd.DataFrame(
265+
{'X': np.random.choice(range(0, 10), size=(10000,)),
266+
'Y': np.random.choice(['one', 'two', 'three'], size=(10000,))})
267+
268+
self.right_object = pd.DataFrame(
269+
{'X': np.random.choice(range(0, 10), size=(10000,)),
270+
'Z': np.random.choice(['jjj', 'kkk', 'sss'], size=(10000,))})
271+
272+
self.left_cat = self.left_object.assign(
273+
Y=self.left_object['Y'].astype('category'))
274+
self.right_cat = self.right_object.assign(
275+
Z=self.right_object['Z'].astype('category'))
276+
277+
def time_merge_object(self):
278+
merge(self.left_object, self.right_object, on='X')
279+
280+
def time_merge_cat(self):
281+
merge(self.left_cat, self.right_cat, on='X')
282+
283+
284+
# ----------------------------------------------------------------------
261285
# Ordered merge
262286

263287
class MergeOrdered(object):
@@ -332,7 +356,7 @@ def time_multiby(self):
332356
merge_asof(self.df1e, self.df2e, on='time', by=['key', 'key2'])
333357

334358

335-
#----------------------------------------------------------------------
359+
# ----------------------------------------------------------------------
336360
# data alignment
337361

338362
class Align(object):

doc/source/whatsnew/v0.20.0.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ Other API Changes
692692
- Reorganization of timeseries development tests (:issue:`14854`)
693693
- Specific support for ``copy.copy()`` and ``copy.deepcopy()`` functions on NDFrame objects (:issue:`15444`)
694694
- ``Series.sort_values()`` accepts a one element list of bool for consistency with the behavior of ``DataFrame.sort_values()`` (:issue:`15604`)
695-
- ``DataFrame.iterkv()`` has been removed in favor of ``DataFrame.iteritems()`` (:issue:`10711`)
695+
- ``.merge()`` and ``.join()`` on ``category`` dtype columns will now preserve the category dtype when possible (:issue:`10409`)
696696

697697
.. _whatsnew_0200.deprecations:
698698

@@ -733,6 +733,7 @@ Removal of prior version deprecations/changes
733733
- ``Series.is_time_series`` is dropped in favor of ``Series.index.is_all_dates`` (:issue:`15098`)
734734
- The deprecated ``irow``, ``icol``, ``iget`` and ``iget_value`` methods are removed
735735
in favor of ``iloc`` and ``iat`` as explained :ref:`here <whatsnew_0170.deprecations>` (:issue:`10711`).
736+
- The deprecated ``DataFrame.iterkv()`` has been removed in favor of ``DataFrame.iteritems()`` (:issue:`10711`)
736737

737738

738739
.. _whatsnew_0200.performance:
@@ -749,6 +750,7 @@ Performance Improvements
749750
- When reading a buffer object in ``read_sas()`` without a specified format, the filepath string is inferred rather than the buffer object. (:issue:`14947`)
750751
- Improved performance of ``.rank()`` for categorical data (:issue:`15498`)
751752
- Improved performance when using ``.unstack()`` (:issue:`15503`)
753+
- Improved performance of merge/join on ``category`` columns (:issue:`10409`)
752754

753755

754756
.. _whatsnew_0200.bug_fixes:

pandas/core/internals.py

+2
Original file line numberDiff line numberDiff line change
@@ -5227,6 +5227,8 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
52275227
# External code requested filling/upcasting, bool values must
52285228
# be upcasted to object to avoid being upcasted to numeric.
52295229
values = self.block.astype(np.object_).values
5230+
elif self.block.is_categorical:
5231+
values = self.block.values
52305232
else:
52315233
# No dtype upcasting is done here, it will be performed during
52325234
# concatenation itself.

pandas/tests/test_categorical.py

+3
Original file line numberDiff line numberDiff line change
@@ -4097,9 +4097,12 @@ def test_merge(self):
40974097
expected = df.copy()
40984098

40994099
# object-cat
4100+
# note that we propagate the category
4101+
# because we don't have any matching rows
41004102
cright = right.copy()
41014103
cright['d'] = cright['d'].astype('category')
41024104
result = pd.merge(left, cright, how='left', left_on='b', right_on='c')
4105+
expected['d'] = expected['d'].astype('category', categories=['null'])
41034106
tm.assert_frame_equal(result, expected)
41044107

41054108
# cat-object

pandas/tests/tools/test_merge.py

+145-32
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=E1103
22

3+
import pytest
34
from datetime import datetime
45
from numpy.random import randn
56
from numpy import nan
@@ -11,6 +12,8 @@
1112
from pandas.tools.concat import concat
1213
from pandas.tools.merge import merge, MergeError
1314
from pandas.util.testing import assert_frame_equal, assert_series_equal
15+
from pandas.types.dtypes import CategoricalDtype
16+
from pandas.types.common import is_categorical_dtype, is_object_dtype
1417
from pandas import DataFrame, Index, MultiIndex, Series, Categorical
1518
import pandas.util.testing as tm
1619

@@ -1024,38 +1027,6 @@ def test_left_join_index_multi_match(self):
10241027
expected.index = np.arange(len(expected))
10251028
tm.assert_frame_equal(result, expected)
10261029

1027-
def test_join_multi_dtypes(self):
1028-
1029-
# test with multi dtypes in the join index
1030-
def _test(dtype1, dtype2):
1031-
left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=dtype1),
1032-
'k2': ['foo', 'bar'] * 12,
1033-
'v': np.array(np.arange(24), dtype=np.int64)})
1034-
1035-
index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
1036-
right = DataFrame(
1037-
{'v2': np.array([5, 7], dtype=dtype2)}, index=index)
1038-
1039-
result = left.join(right, on=['k1', 'k2'])
1040-
1041-
expected = left.copy()
1042-
1043-
if dtype2.kind == 'i':
1044-
dtype2 = np.dtype('float64')
1045-
expected['v2'] = np.array(np.nan, dtype=dtype2)
1046-
expected.loc[(expected.k1 == 2) & (expected.k2 == 'bar'), 'v2'] = 5
1047-
expected.loc[(expected.k1 == 1) & (expected.k2 == 'foo'), 'v2'] = 7
1048-
1049-
tm.assert_frame_equal(result, expected)
1050-
1051-
result = left.join(right, on=['k1', 'k2'], sort=True)
1052-
expected.sort_values(['k1', 'k2'], kind='mergesort', inplace=True)
1053-
tm.assert_frame_equal(result, expected)
1054-
1055-
for d1 in [np.int64, np.int32, np.int16, np.int8, np.uint8]:
1056-
for d2 in [np.int64, np.float64, np.float32, np.float16]:
1057-
_test(np.dtype(d1), np.dtype(d2))
1058-
10591030
def test_left_merge_na_buglet(self):
10601031
left = DataFrame({'id': list('abcde'), 'v1': randn(5),
10611032
'v2': randn(5), 'dummy': list('abcde'),
@@ -1242,3 +1213,145 @@ def f():
12421213
def f():
12431214
household.join(log_return, how='outer')
12441215
self.assertRaises(NotImplementedError, f)
1216+
1217+
1218+
@pytest.fixture
1219+
def df():
1220+
return DataFrame(
1221+
{'A': ['foo', 'bar'],
1222+
'B': Series(['foo', 'bar']).astype('category'),
1223+
'C': [1, 2],
1224+
'D': [1.0, 2.0],
1225+
'E': Series([1, 2], dtype='uint64'),
1226+
'F': Series([1, 2], dtype='int32')})
1227+
1228+
1229+
class TestMergeDtypes(object):
1230+
1231+
def test_different(self, df):
1232+
1233+
# we expect differences by kind
1234+
# to be ok, while other differences should return object
1235+
1236+
left = df
1237+
for col in df.columns:
1238+
right = DataFrame({'A': df[col]})
1239+
result = pd.merge(left, right, on='A')
1240+
assert is_object_dtype(result.A.dtype)
1241+
1242+
@pytest.mark.parametrize('d1', [np.int64, np.int32,
1243+
np.int16, np.int8, np.uint8])
1244+
@pytest.mark.parametrize('d2', [np.int64, np.float64,
1245+
np.float32, np.float16])
1246+
def test_join_multi_dtypes(self, d1, d2):
1247+
1248+
dtype1 = np.dtype(d1)
1249+
dtype2 = np.dtype(d2)
1250+
1251+
left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=dtype1),
1252+
'k2': ['foo', 'bar'] * 12,
1253+
'v': np.array(np.arange(24), dtype=np.int64)})
1254+
1255+
index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
1256+
right = DataFrame({'v2': np.array([5, 7], dtype=dtype2)}, index=index)
1257+
1258+
result = left.join(right, on=['k1', 'k2'])
1259+
1260+
expected = left.copy()
1261+
1262+
if dtype2.kind == 'i':
1263+
dtype2 = np.dtype('float64')
1264+
expected['v2'] = np.array(np.nan, dtype=dtype2)
1265+
expected.loc[(expected.k1 == 2) & (expected.k2 == 'bar'), 'v2'] = 5
1266+
expected.loc[(expected.k1 == 1) & (expected.k2 == 'foo'), 'v2'] = 7
1267+
1268+
tm.assert_frame_equal(result, expected)
1269+
1270+
result = left.join(right, on=['k1', 'k2'], sort=True)
1271+
expected.sort_values(['k1', 'k2'], kind='mergesort', inplace=True)
1272+
tm.assert_frame_equal(result, expected)
1273+
1274+
1275+
@pytest.fixture
1276+
def left():
1277+
np.random.seed(1234)
1278+
return DataFrame(
1279+
{'X': Series(np.random.choice(
1280+
['foo', 'bar'],
1281+
size=(10,))).astype('category', categories=['foo', 'bar']),
1282+
'Y': np.random.choice(['one', 'two', 'three'], size=(10,))})
1283+
1284+
1285+
@pytest.fixture
1286+
def right():
1287+
np.random.seed(1234)
1288+
return DataFrame(
1289+
{'X': Series(['foo', 'bar']).astype('category',
1290+
categories=['foo', 'bar']),
1291+
'Z': [1, 2]})
1292+
1293+
1294+
class TestMergeCategorical(object):
1295+
1296+
def test_identical(self, left):
1297+
# merging a frame with itself should preserve dtypes
1298+
merged = pd.merge(left, left, on='X')
1299+
result = merged.dtypes.sort_index()
1300+
expected = Series([CategoricalDtype(),
1301+
np.dtype('O'),
1302+
np.dtype('O')],
1303+
index=['X', 'Y_x', 'Y_y'])
1304+
assert_series_equal(result, expected)
1305+
1306+
def test_basic(self, left, right):
1307+
# we have matching Categorical dtypes in X
1308+
# so should preserve the merged column
1309+
merged = pd.merge(left, right, on='X')
1310+
result = merged.dtypes.sort_index()
1311+
expected = Series([CategoricalDtype(),
1312+
np.dtype('O'),
1313+
np.dtype('int64')],
1314+
index=['X', 'Y', 'Z'])
1315+
assert_series_equal(result, expected)
1316+
1317+
def test_other_columns(self, left, right):
1318+
# non-merge columns should preserve if possible
1319+
right = right.assign(Z=right.Z.astype('category'))
1320+
1321+
merged = pd.merge(left, right, on='X')
1322+
result = merged.dtypes.sort_index()
1323+
expected = Series([CategoricalDtype(),
1324+
np.dtype('O'),
1325+
CategoricalDtype()],
1326+
index=['X', 'Y', 'Z'])
1327+
assert_series_equal(result, expected)
1328+
1329+
# categories are preserved
1330+
assert left.X.values.is_dtype_equal(merged.X.values)
1331+
assert right.Z.values.is_dtype_equal(merged.Z.values)
1332+
1333+
@pytest.mark.parametrize(
1334+
'change', [lambda x: x,
1335+
lambda x: x.astype('category',
1336+
categories=['bar', 'foo']),
1337+
lambda x: x.astype('category',
1338+
categories=['foo', 'bar', 'bah']),
1339+
lambda x: x.astype('category', ordered=True)])
1340+
@pytest.mark.parametrize('how', ['inner', 'outer', 'left', 'right'])
1341+
def test_dtype_on_merged_different(self, change, how, left, right):
1342+
# our merging columns, X now has 2 different dtypes
1343+
# so we must be object as a result
1344+
1345+
X = change(right.X.astype('object'))
1346+
right = right.assign(X=X)
1347+
assert is_categorical_dtype(left.X.values)
1348+
assert not left.X.values.is_dtype_equal(right.X.values)
1349+
1350+
merged = pd.merge(left, right, on='X', how=how)
1351+
1352+
result = merged.dtypes.sort_index()
1353+
expected = Series([np.dtype('O'),
1354+
np.dtype('O'),
1355+
np.dtype('int64')],
1356+
index=['X', 'Y', 'Z'])
1357+
assert_series_equal(result, expected)

pandas/tests/tools/test_merge_asof.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_basic_categorical(self):
147147
trades.ticker = trades.ticker.astype('category')
148148
quotes = self.quotes.copy()
149149
quotes.ticker = quotes.ticker.astype('category')
150+
expected.ticker = expected.ticker.astype('category')
150151

151152
result = merge_asof(trades, quotes,
152153
on='time',

pandas/tests/types/test_common.py

+39-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3+
import pytest
34
import numpy as np
45

56
from pandas.types.dtypes import DatetimeTZDtype, PeriodDtype, CategoricalDtype
@@ -38,17 +39,44 @@ def test_period_dtype(self):
3839
self.assertEqual(pandas_dtype(dtype), dtype)
3940

4041

41-
def test_dtype_equal():
42-
assert is_dtype_equal(np.int64, np.int64)
43-
assert not is_dtype_equal(np.int64, np.float64)
42+
dtypes = dict(datetime_tz=pandas_dtype('datetime64[ns, US/Eastern]'),
43+
datetime=pandas_dtype('datetime64[ns]'),
44+
timedelta=pandas_dtype('timedelta64[ns]'),
45+
period=PeriodDtype('D'),
46+
integer=np.dtype(np.int64),
47+
float=np.dtype(np.float64),
48+
object=np.dtype(np.object),
49+
category=pandas_dtype('category'))
4450

45-
p1 = PeriodDtype('D')
46-
p2 = PeriodDtype('D')
47-
assert is_dtype_equal(p1, p2)
48-
assert not is_dtype_equal(np.int64, p1)
4951

50-
p3 = PeriodDtype('2D')
51-
assert not is_dtype_equal(p1, p3)
52+
@pytest.mark.parametrize('name1,dtype1',
53+
list(dtypes.items()),
54+
ids=lambda x: str(x))
55+
@pytest.mark.parametrize('name2,dtype2',
56+
list(dtypes.items()),
57+
ids=lambda x: str(x))
58+
def test_dtype_equal(name1, dtype1, name2, dtype2):
5259

53-
assert not DatetimeTZDtype.is_dtype(np.int64)
54-
assert not PeriodDtype.is_dtype(np.int64)
60+
# each dtype is equal to itself, but not equal to any other dtype
61+
assert is_dtype_equal(dtype1, dtype1)
62+
if name1 != name2:
63+
assert not is_dtype_equal(dtype1, dtype2)
64+
65+
66+
def test_dtype_equal_strict():
67+
68+
# we are strict on kind equality
69+
for dtype in [np.int8, np.int16, np.int32]:
70+
assert not is_dtype_equal(np.int64, dtype)
71+
72+
for dtype in [np.float32]:
73+
assert not is_dtype_equal(np.float64, dtype)
74+
75+
# strict w.r.t. PeriodDtype
76+
assert not is_dtype_equal(PeriodDtype('D'),
77+
PeriodDtype('2D'))
78+
79+
# strict w.r.t. datetime64
80+
assert not is_dtype_equal(
81+
pandas_dtype('datetime64[ns, US/Eastern]'),
82+
pandas_dtype('datetime64[ns, CET]'))

0 commit comments

Comments
 (0)