Skip to content

Commit 0ae7e90

Browse files
jorisvandenbosschejreback
authored andcommitted
Fix pd.merge to preserve ExtensionArrays dtypes (pandas-dev#20745)
1 parent 4de2e9b commit 0ae7e90

File tree

5 files changed

+43
-4
lines changed

5 files changed

+43
-4
lines changed

pandas/core/dtypes/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1807,7 +1807,7 @@ def _get_dtype(arr_or_dtype):
18071807
return arr_or_dtype
18081808
elif isinstance(arr_or_dtype, type):
18091809
return np.dtype(arr_or_dtype)
1810-
elif isinstance(arr_or_dtype, CategoricalDtype):
1810+
elif isinstance(arr_or_dtype, ExtensionDtype):
18111811
return arr_or_dtype
18121812
elif isinstance(arr_or_dtype, DatetimeTZDtype):
18131813
return arr_or_dtype

pandas/core/internals.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -5541,8 +5541,14 @@ def concatenate_join_units(join_units, concat_axis, copy):
55415541
if len(to_concat) == 1:
55425542
# Only one block, nothing to concatenate.
55435543
concat_values = to_concat[0]
5544-
if copy and concat_values.base is not None:
5545-
concat_values = concat_values.copy()
5544+
if copy:
5545+
if isinstance(concat_values, np.ndarray):
5546+
# non-reindexed (=not yet copied) arrays are made into a view
5547+
# in JoinUnit.get_reindexed_values
5548+
if concat_values.base is not None:
5549+
concat_values = concat_values.copy()
5550+
else:
5551+
concat_values = concat_values.copy()
55465552
else:
55475553
concat_values = _concat._concat_compat(to_concat, axis=concat_axis)
55485554

@@ -5823,7 +5829,7 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
58235829
# External code requested filling/upcasting, bool values must
58245830
# be upcasted to object to avoid being upcasted to numeric.
58255831
values = self.block.astype(np.object_).values
5826-
elif self.block.is_categorical:
5832+
elif self.block.is_extension:
58275833
values = self.block.values
58285834
else:
58295835
# No dtype upcasting is done here, it will be performed during

pandas/tests/extension/base/reshaping.py

+21
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,24 @@ def test_set_frame_overwrite_object(self, data):
9595
df = pd.DataFrame({"A": [1] * len(data)}, dtype=object)
9696
df['A'] = data
9797
assert df.dtypes['A'] == data.dtype
98+
99+
def test_merge(self, data, na_value):
100+
# GH-20743
101+
df1 = pd.DataFrame({'ext': data[:3], 'int1': [1, 2, 3],
102+
'key': [0, 1, 2]})
103+
df2 = pd.DataFrame({'int2': [1, 2, 3, 4], 'key': [0, 0, 1, 3]})
104+
105+
res = pd.merge(df1, df2)
106+
exp = pd.DataFrame(
107+
{'int1': [1, 1, 2], 'int2': [1, 2, 3], 'key': [0, 0, 1],
108+
'ext': data._constructor_from_sequence(
109+
[data[0], data[0], data[1]])})
110+
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
111+
112+
res = pd.merge(df1, df2, how='outer')
113+
exp = pd.DataFrame(
114+
{'int1': [1, 1, 2, 3, np.nan], 'int2': [1, 2, 3, np.nan, 4],
115+
'key': [0, 0, 1, 2, 3],
116+
'ext': data._constructor_from_sequence(
117+
[data[0], data[0], data[1], data[2], na_value])})
118+
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])

pandas/tests/extension/category/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def test_align(self, data, na_value):
7575
def test_align_frame(self, data, na_value):
7676
pass
7777

78+
@pytest.mark.skip(reason="Unobserved categories preseved in concat.")
79+
def test_merge(self, data, na_value):
80+
pass
81+
7882

7983
class TestGetitem(base.BaseGetitemTests):
8084
@pytest.mark.skip(reason="Backwards compatibility")

pandas/tests/extension/decimal/test_decimal.py

+8
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ def assert_series_equal(self, left, right, *args, **kwargs):
7272

7373
def assert_frame_equal(self, left, right, *args, **kwargs):
7474
# TODO(EA): select_dtypes
75+
tm.assert_index_equal(
76+
left.columns, right.columns,
77+
exact=kwargs.get('check_column_type', 'equiv'),
78+
check_names=kwargs.get('check_names', True),
79+
check_exact=kwargs.get('check_exact', False),
80+
check_categorical=kwargs.get('check_categorical', True),
81+
obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')))
82+
7583
decimals = (left.dtypes == 'decimal').index
7684

7785
for col in decimals:

0 commit comments

Comments
 (0)