Skip to content

Commit 21f5fb1

Browse files
fjetterjreback
authored andcommitted
BUG: Fix KeyError in merge on CategoricalIndex (pandas-dev#20777)
1 parent d3d3352 commit 21f5fb1

File tree

4 files changed

+58
-25
lines changed

4 files changed

+58
-25
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,7 @@ Sparse
13271327
Reshaping
13281328
^^^^^^^^^
13291329

1330+
- Bug in :func:`DataFrame.merge` where referencing a ``CategoricalIndex`` by name, where the ``by`` kwarg would ``KeyError`` (:issue:`20777`)
13301331
- Bug in :func:`DataFrame.stack` which fails trying to sort mixed type levels under Python 3 (:issue:`18310`)
13311332
- Bug in :func:`DataFrame.unstack` which casts int to float if ``columns`` is a ``MultiIndex`` with unused levels (:issue:`17845`)
13321333
- Bug in :func:`DataFrame.unstack` which raises an error if ``index`` is a ``MultiIndex`` with unused labels on the unstacked level (:issue:`18562`)

pandas/core/algorithms.py

+2
Original file line numberDiff line numberDiff line change
@@ -1585,6 +1585,8 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan, mask_info=None,
15851585

15861586
if is_sparse(arr):
15871587
arr = arr.get_values()
1588+
elif isinstance(arr, (ABCIndexClass, ABCSeries)):
1589+
arr = arr.values
15881590

15891591
arr = np.asarray(arr)
15901592

pandas/core/reshape/merge.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -705,8 +705,7 @@ def _maybe_add_join_keys(self, result, left_indexer, right_indexer):
705705
take_right = self.right[name]._values
706706

707707
elif left_indexer is not None \
708-
and isinstance(self.left_join_keys[i], np.ndarray):
709-
708+
and is_array_like(self.left_join_keys[i]):
710709
take_left = self.left_join_keys[i]
711710
take_right = self.right_join_keys[i]
712711

pandas/tests/reshape/merge/test_merge.py

+54-23
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
# pylint: disable=E1103
22

3-
import pytest
4-
from datetime import datetime, date
5-
from numpy.random import randn
6-
from numpy import nan
7-
import numpy as np
83
import random
94
import re
5+
from collections import OrderedDict
6+
from datetime import date, datetime
7+
8+
import numpy as np
9+
import pytest
10+
from numpy import nan
11+
from numpy.random import randn
1012

1113
import pandas as pd
14+
import pandas.util.testing as tm
15+
from pandas import (Categorical, CategoricalIndex, DataFrame, DatetimeIndex,
16+
Float64Index, Index, Int64Index, MultiIndex, RangeIndex,
17+
Series, UInt64Index)
18+
from pandas.api.types import CategoricalDtype as CDT
1219
from pandas.compat import lrange, lzip
20+
from pandas.core.dtypes.common import is_categorical_dtype, is_object_dtype
21+
from pandas.core.dtypes.dtypes import CategoricalDtype
1322
from pandas.core.reshape.concat import concat
14-
from pandas.core.reshape.merge import merge, MergeError
23+
from pandas.core.reshape.merge import MergeError, merge
1524
from pandas.util.testing import assert_frame_equal, assert_series_equal
16-
from pandas.core.dtypes.dtypes import CategoricalDtype
17-
from pandas.core.dtypes.common import (
18-
is_categorical_dtype,
19-
is_object_dtype,
20-
)
21-
from pandas import DataFrame, Index, MultiIndex, Series, Categorical
22-
import pandas.util.testing as tm
23-
from pandas.api.types import CategoricalDtype as CDT
2425

2526
N = 50
2627
NGROUPS = 8
@@ -813,7 +814,7 @@ def test_validation(self):
813814

814815
# Dups on right
815816
right_w_dups = right.append(pd.DataFrame({'a': ['e'], 'c': ['moo']},
816-
index=[4]))
817+
index=[4]))
817818
merge(left, right_w_dups, left_index=True, right_index=True,
818819
validate='one_to_many')
819820

@@ -1388,17 +1389,24 @@ def test_merge_datetime_index(self, klass):
13881389
if klass is not None:
13891390
on_vector = klass(on_vector)
13901391

1391-
expected = DataFrame({"a": [1, 2, 3]})
1392-
1393-
if klass == np.asarray:
1394-
# The join key is added for ndarray.
1395-
expected["key_1"] = [2016, 2017, 2018]
1392+
expected = DataFrame(
1393+
OrderedDict([
1394+
("a", [1, 2, 3]),
1395+
("key_1", [2016, 2017, 2018]),
1396+
])
1397+
)
13961398

13971399
result = df.merge(df, on=["a", on_vector], how="inner")
13981400
tm.assert_frame_equal(result, expected)
13991401

1400-
expected = DataFrame({"a_x": [1, 2, 3],
1401-
"a_y": [1, 2, 3]})
1402+
expected = DataFrame(
1403+
OrderedDict([
1404+
("key_0", [2016, 2017, 2018]),
1405+
("a_x", [1, 2, 3]),
1406+
("a_y", [1, 2, 3]),
1407+
])
1408+
)
1409+
14021410
result = df.merge(df, on=[df.index.year], how="inner")
14031411
tm.assert_frame_equal(result, expected)
14041412

@@ -1427,7 +1435,7 @@ def test_different(self, right_vals):
14271435
# We allow merging on object and categorical cols and cast
14281436
# categorical cols to object
14291437
if (is_categorical_dtype(right['A'].dtype) or
1430-
is_object_dtype(right['A'].dtype)):
1438+
is_object_dtype(right['A'].dtype)):
14311439
result = pd.merge(left, right, on='A')
14321440
assert is_object_dtype(result.A.dtype)
14331441

@@ -1826,3 +1834,26 @@ def test_merge_on_indexes(self, left_df, right_df, how, sort, expected):
18261834
how=how,
18271835
sort=sort)
18281836
tm.assert_frame_equal(result, expected)
1837+
1838+
1839+
@pytest.mark.parametrize(
1840+
'index', [
1841+
CategoricalIndex(['A', 'B'], categories=['A', 'B'], name='index_col'),
1842+
Float64Index([1.0, 2.0], name='index_col'),
1843+
Int64Index([1, 2], name='index_col'),
1844+
UInt64Index([1, 2], name='index_col'),
1845+
RangeIndex(start=0, stop=2, name='index_col'),
1846+
DatetimeIndex(["2018-01-01", "2018-01-02"], name='index_col'),
1847+
], ids=lambda x: type(x).__name__)
1848+
def test_merge_index_types(index):
1849+
# gh-20777
1850+
# assert key access is consistent across index types
1851+
left = DataFrame({"left_data": [1, 2]}, index=index)
1852+
right = DataFrame({"right_data": [1.0, 2.0]}, index=index)
1853+
1854+
result = left.merge(right, on=['index_col'])
1855+
1856+
expected = DataFrame(
1857+
OrderedDict([('left_data', [1, 2]), ('right_data', [1.0, 2.0])]),
1858+
index=index)
1859+
assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)