Skip to content

Commit e28b353

Browse files
tptopper-123
tp
authored andcommitted
make Int8/16/32Engine work with Int64HashTable
1 parent dc4b6f8 commit e28b353

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

pandas/_libs/algos_common_helper.pxi.in

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dtypes = [('float64', 'float64_t', 'np.float64', True, True),
2828
('int64', 'int64_t', 'np.int64', False, True),
2929
('uint64', 'uint64_t', 'np.uint64', False, True),
3030
('int32', 'int32_t', 'np.int32', False, True),
31+
('int16', 'int16_t', 'np.int16', False, True),
3132
('int8', 'int8_t', 'np.int8', False, True),
3233
('bool', 'uint8_t', 'np.bool', False, True)]
3334

pandas/_libs/index.pyx

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ from cpython.slice cimport PySlice_Check
99
import numpy as np
1010
cimport numpy as cnp
1111
from numpy cimport (ndarray, float64_t, int32_t,
12-
int64_t, uint8_t, uint64_t, intp_t,
12+
int8_t, int16_t, int32_t, int64_t,
13+
uint8_t, uint64_t,
14+
intp_t,
1315
# Note: NPY_DATETIME, NPY_TIMEDELTA are only available
1416
# for cimport in cython>=0.27.3
1517
NPY_DATETIME, NPY_TIMEDELTA)
@@ -264,6 +266,8 @@ cdef class IndexEngine:
264266
if not self.is_mapping_populated:
265267

266268
values = self._get_index_values()
269+
if values.dtype in {'int8', 'int16', 'int32'}:
270+
values = algos.ensure_int64(values)
267271
self.mapping = self._make_hash_table(len(values))
268272
self._call_map_locations(values)
269273

pandas/_libs/index_class_helper.pxi.in

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ cdef class {{name}}Engine(IndexEngine):
4040
cdef _make_hash_table(self, n):
4141
{{if name == 'Object'}}
4242
return _hash.PyObjectHashTable(n)
43+
{{elif name in {'Int8', 'Int16', 'Int32'} }}
44+
return _hash.Int64HashTable(n)
4345
{{else}}
4446
return _hash.{{name}}HashTable(n)
4547
{{endif}}

pandas/core/indexes/category.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -437,18 +437,12 @@ def get_loc(self, key, method=None):
437437
>>> non_monotonic_index.get_loc('b')
438438
array([False, True, False, True], dtype=bool)
439439
"""
440-
codes = self.categories.get_loc(key)
441-
if (codes == -1):
440+
code = self.categories.get_loc(key)
441+
if (code == -1):
442442
raise KeyError(key)
443443

444-
if self.is_monotonic_increasing and not self.is_unique:
445-
if codes not in self._engine:
446-
raise KeyError(key)
447-
codes = self.codes.dtype.type(codes)
448-
lhs = self.codes.searchsorted(codes, side='left')
449-
rhs = self.codes.searchsorted(codes, side='right')
450-
return slice(lhs, rhs)
451-
return self._engine.get_loc(codes)
444+
code = self.codes.dtype.type(code)
445+
return self._engine.get_loc(code)
452446

453447
def get_value(self, series, key):
454448
"""

0 commit comments

Comments
 (0)