Skip to content

Commit c5af9ac

Browse files
jbrockmendel authored and JulianWgs committed
CLN: factorize returns ndarray[intp], not int64 (pandas-dev#40474)
1 parent 77126a3 commit c5af9ac

File tree

5 files changed

+65
-48
lines changed

5 files changed

+65
-48
lines changed

pandas/_libs/hashtable.pyx

+18-8
Original file line number | Diff line number | Diff line change
@@ -66,20 +66,28 @@ cdef class Factorizer:
6666
self.uniques = ObjectVector()
6767
self.count = 0
6868

69-
def get_count(self):
69+
def get_count(self) -> int:
7070
return self.count
7171

7272
def factorize(
7373
self, ndarray[object] values, sort=False, na_sentinel=-1, na_value=None
74-
):
74+
) -> np.ndarray:
7575
"""
76+
77+
Returns
78+
-------
79+
np.ndarray[np.intp]
80+
7681
Examples
7782
--------
7883
Factorize values with nans replaced by na_sentinel
7984

8085
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
8186
array([ 0, 1, 20])
8287
"""
88+
cdef:
89+
ndarray[intp_t] labels
90+
8391
if self.uniques.external_view_exists:
8492
uniques = ObjectVector()
8593
uniques.extend(self.uniques.to_array())
@@ -89,8 +97,6 @@ cdef class Factorizer:
8997
mask = (labels == na_sentinel)
9098
# sort on
9199
if sort:
92-
if labels.dtype != np.intp:
93-
labels = labels.astype(np.intp)
94100
sorter = self.uniques.to_array().argsort()
95101
reverse_indexer = np.empty(len(sorter), dtype=np.intp)
96102
reverse_indexer.put(sorter, np.arange(len(sorter)))
@@ -119,15 +125,22 @@ cdef class Int64Factorizer:
119125
return self.count
120126

121127
def factorize(self, const int64_t[:] values, sort=False,
122-
na_sentinel=-1, na_value=None):
128+
na_sentinel=-1, na_value=None) -> np.ndarray:
123129
"""
130+
Returns
131+
-------
132+
ndarray[intp_t]
133+
124134
Examples
125135
--------
126136
Factorize values with nans replaced by na_sentinel
127137

128138
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
129139
array([ 0, 1, 20])
130140
"""
141+
cdef:
142+
ndarray[intp_t] labels
143+
131144
if self.uniques.external_view_exists:
132145
uniques = Int64Vector()
133146
uniques.extend(self.uniques.to_array())
@@ -138,9 +151,6 @@ cdef class Int64Factorizer:
138151

139152
# sort on
140153
if sort:
141-
if labels.dtype != np.intp:
142-
labels = labels.astype(np.intp)
143-
144154
sorter = self.uniques.to_array().argsort()
145155
reverse_indexer = np.empty(len(sorter), dtype=np.intp)
146156
reverse_indexer.put(sorter, np.arange(len(sorter)))

pandas/_libs/hashtable_class_helper.pxi.in

+18-18
Original file line number | Diff line number | Diff line change
@@ -539,12 +539,12 @@ cdef class {{name}}HashTable(HashTable):
539539
-------
540540
uniques : ndarray[{{dtype}}]
541541
Unique values of input, not sorted
542-
labels : ndarray[int64] (if return_inverse=True)
542+
labels : ndarray[intp_t] (if return_inverse=True)
543543
The labels from values to uniques
544544
"""
545545
cdef:
546546
Py_ssize_t i, idx, count = count_prior, n = len(values)
547-
int64_t[:] labels
547+
intp_t[:] labels
548548
int ret = 0
549549
{{c_type}} val, na_value2
550550
khiter_t k
@@ -553,7 +553,7 @@ cdef class {{name}}HashTable(HashTable):
553553
uint8_t[:] mask_values
554554

555555
if return_inverse:
556-
labels = np.empty(n, dtype=np.int64)
556+
labels = np.empty(n, dtype=np.intp)
557557
ud = uniques.data
558558
use_na_value = na_value is not None
559559
use_mask = mask is not None
@@ -614,7 +614,7 @@ cdef class {{name}}HashTable(HashTable):
614614
labels[i] = idx
615615

616616
if return_inverse:
617-
return uniques.to_array(), np.asarray(labels)
617+
return uniques.to_array(), labels.base # .base -> underlying ndarray
618618
return uniques.to_array()
619619

620620
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False):
@@ -633,7 +633,7 @@ cdef class {{name}}HashTable(HashTable):
633633
-------
634634
uniques : ndarray[{{dtype}}]
635635
Unique values of input, not sorted
636-
labels : ndarray[int64] (if return_inverse)
636+
labels : ndarray[intp_t] (if return_inverse)
637637
The labels from values to uniques
638638
"""
639639
uniques = {{name}}Vector()
@@ -668,7 +668,7 @@ cdef class {{name}}HashTable(HashTable):
668668
-------
669669
uniques : ndarray[{{dtype}}]
670670
Unique values of input, not sorted
671-
labels : ndarray[int64]
671+
labels : ndarray[intp_t]
672672
The labels from values to uniques
673673
"""
674674
uniques_vector = {{name}}Vector()
@@ -918,12 +918,12 @@ cdef class StringHashTable(HashTable):
918918
-------
919919
uniques : ndarray[object]
920920
Unique values of input, not sorted
921-
labels : ndarray[int64] (if return_inverse=True)
921+
labels : ndarray[intp_t] (if return_inverse=True)
922922
The labels from values to uniques
923923
"""
924924
cdef:
925925
Py_ssize_t i, idx, count = count_prior, n = len(values)
926-
int64_t[:] labels
926+
intp_t[:] labels
927927
int64_t[:] uindexer
928928
int ret = 0
929929
object val
@@ -933,7 +933,7 @@ cdef class StringHashTable(HashTable):
933933
bint use_na_value
934934

935935
if return_inverse:
936-
labels = np.zeros(n, dtype=np.int64)
936+
labels = np.zeros(n, dtype=np.intp)
937937
uindexer = np.empty(n, dtype=np.int64)
938938
use_na_value = na_value is not None
939939

@@ -972,13 +972,13 @@ cdef class StringHashTable(HashTable):
972972
uindexer[count] = i
973973
if return_inverse:
974974
self.table.vals[k] = count
975-
labels[i] = <int64_t>count
975+
labels[i] = count
976976
count += 1
977977
elif return_inverse:
978978
# k falls into a previous bucket
979979
# only relevant in case we need to construct the inverse
980980
idx = self.table.vals[k]
981-
labels[i] = <int64_t>idx
981+
labels[i] = idx
982982

983983
free(vecs)
984984

@@ -987,7 +987,7 @@ cdef class StringHashTable(HashTable):
987987
uniques.append(values[uindexer[i]])
988988

989989
if return_inverse:
990-
return uniques.to_array(), np.asarray(labels)
990+
return uniques.to_array(), labels.base # .base -> underlying ndarray
991991
return uniques.to_array()
992992

993993
def unique(self, ndarray[object] values, bint return_inverse=False):
@@ -1193,19 +1193,19 @@ cdef class PyObjectHashTable(HashTable):
11931193
-------
11941194
uniques : ndarray[object]
11951195
Unique values of input, not sorted
1196-
labels : ndarray[int64] (if return_inverse=True)
1196+
labels : ndarray[intp_t] (if return_inverse=True)
11971197
The labels from values to uniques
11981198
"""
11991199
cdef:
12001200
Py_ssize_t i, idx, count = count_prior, n = len(values)
1201-
int64_t[:] labels
1201+
intp_t[:] labels
12021202
int ret = 0
12031203
object val
12041204
khiter_t k
12051205
bint use_na_value
12061206

12071207
if return_inverse:
1208-
labels = np.empty(n, dtype=np.int64)
1208+
labels = np.empty(n, dtype=np.intp)
12091209
use_na_value = na_value is not None
12101210

12111211
for i in range(n):
@@ -1240,7 +1240,7 @@ cdef class PyObjectHashTable(HashTable):
12401240
labels[i] = idx
12411241

12421242
if return_inverse:
1243-
return uniques.to_array(), np.asarray(labels)
1243+
return uniques.to_array(), labels.base # .base -> underlying ndarray
12441244
return uniques.to_array()
12451245

12461246
def unique(self, ndarray[object] values, bint return_inverse=False):
@@ -1259,7 +1259,7 @@ cdef class PyObjectHashTable(HashTable):
12591259
-------
12601260
uniques : ndarray[object]
12611261
Unique values of input, not sorted
1262-
labels : ndarray[int64] (if return_inverse)
1262+
labels : ndarray[intp_t] (if return_inverse)
12631263
The labels from values to uniques
12641264
"""
12651265
uniques = ObjectVector()
@@ -1292,7 +1292,7 @@ cdef class PyObjectHashTable(HashTable):
12921292
-------
12931293
uniques : ndarray[object]
12941294
Unique values of input, not sorted
1295-
labels : ndarray[int64]
1295+
labels : ndarray[intp_t]
12961296
The labels from values to uniques
12971297
"""
12981298
uniques_vector = ObjectVector()

pandas/_libs/join.pyx

+11-9
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ from numpy cimport (
1010
int16_t,
1111
int32_t,
1212
int64_t,
13+
intp_t,
1314
ndarray,
1415
uint8_t,
1516
uint16_t,
@@ -20,14 +21,15 @@ from numpy cimport (
2021
cnp.import_array()
2122

2223
from pandas._libs.algos import (
24+
ensure_int64,
2325
ensure_platform_int,
2426
groupsort_indexer,
2527
take_1d_int64_int64,
2628
)
2729

2830

2931
@cython.boundscheck(False)
30-
def inner_join(const int64_t[:] left, const int64_t[:] right,
32+
def inner_join(const intp_t[:] left, const intp_t[:] right,
3133
Py_ssize_t max_groups):
3234
cdef:
3335
Py_ssize_t i, j, k, count = 0
@@ -39,8 +41,8 @@ def inner_join(const int64_t[:] left, const int64_t[:] right,
3941

4042
# NA group in location 0
4143

42-
left_sorter, left_count = groupsort_indexer(left, max_groups)
43-
right_sorter, right_count = groupsort_indexer(right, max_groups)
44+
left_sorter, left_count = groupsort_indexer(ensure_int64(left), max_groups)
45+
right_sorter, right_count = groupsort_indexer(ensure_int64(right), max_groups)
4446

4547
with nogil:
4648
# First pass, determine size of result set, do not use the NA group
@@ -78,7 +80,7 @@ def inner_join(const int64_t[:] left, const int64_t[:] right,
7880

7981

8082
@cython.boundscheck(False)
81-
def left_outer_join(const int64_t[:] left, const int64_t[:] right,
83+
def left_outer_join(const intp_t[:] left, const intp_t[:] right,
8284
Py_ssize_t max_groups, bint sort=True):
8385
cdef:
8486
Py_ssize_t i, j, k, count = 0
@@ -91,8 +93,8 @@ def left_outer_join(const int64_t[:] left, const int64_t[:] right,
9193

9294
# NA group in location 0
9395

94-
left_sorter, left_count = groupsort_indexer(left, max_groups)
95-
right_sorter, right_count = groupsort_indexer(right, max_groups)
96+
left_sorter, left_count = groupsort_indexer(ensure_int64(left), max_groups)
97+
right_sorter, right_count = groupsort_indexer(ensure_int64(right), max_groups)
9698

9799
with nogil:
98100
# First pass, determine size of result set, do not use the NA group
@@ -151,7 +153,7 @@ def left_outer_join(const int64_t[:] left, const int64_t[:] right,
151153

152154

153155
@cython.boundscheck(False)
154-
def full_outer_join(const int64_t[:] left, const int64_t[:] right,
156+
def full_outer_join(const intp_t[:] left, const intp_t[:] right,
155157
Py_ssize_t max_groups):
156158
cdef:
157159
Py_ssize_t i, j, k, count = 0
@@ -163,8 +165,8 @@ def full_outer_join(const int64_t[:] left, const int64_t[:] right,
163165

164166
# NA group in location 0
165167

166-
left_sorter, left_count = groupsort_indexer(left, max_groups)
167-
right_sorter, right_count = groupsort_indexer(right, max_groups)
168+
left_sorter, left_count = groupsort_indexer(ensure_int64(left), max_groups)
169+
right_sorter, right_count = groupsort_indexer(ensure_int64(right), max_groups)
168170

169171
with nogil:
170172
# First pass, determine size of result set, do not use the NA group

pandas/core/reshape/merge.py

+10-5
Original file line number | Diff line number | Diff line change
@@ -1973,7 +1973,7 @@ def _get_single_indexer(join_key, index, sort: bool = False):
19731973
left_key, right_key, count = _factorize_keys(join_key, index, sort=sort)
19741974

19751975
left_indexer, right_indexer = libjoin.left_outer_join(
1976-
ensure_int64(left_key), ensure_int64(right_key), count, sort=sort
1976+
left_key, right_key, count, sort=sort
19771977
)
19781978

19791979
return left_indexer, right_indexer
@@ -2029,9 +2029,9 @@ def _factorize_keys(
20292029
20302030
Returns
20312031
-------
2032-
array
2032+
np.ndarray[np.intp]
20332033
Left (resp. right if called with `key='right'`) labels, as enumerated type.
2034-
array
2034+
np.ndarray[np.intp]
20352035
Right (resp. left if called with `key='right'`) labels, as enumerated type.
20362036
int
20372037
Number of unique elements in union of left and right labels.
@@ -2117,6 +2117,8 @@ def _factorize_keys(
21172117

21182118
llab = rizer.factorize(lk)
21192119
rlab = rizer.factorize(rk)
2120+
assert llab.dtype == np.intp, llab.dtype
2121+
assert rlab.dtype == np.intp, rlab.dtype
21202122

21212123
count = rizer.get_count()
21222124

@@ -2142,13 +2144,16 @@ def _factorize_keys(
21422144
return llab, rlab, count
21432145

21442146

2145-
def _sort_labels(uniques: np.ndarray, left, right):
2147+
def _sort_labels(
2148+
uniques: np.ndarray, left: np.ndarray, right: np.ndarray
2149+
) -> tuple[np.ndarray, np.ndarray]:
2150+
# Both returned ndarrays are np.intp
21462151

21472152
llength = len(left)
21482153
labels = np.concatenate([left, right])
21492154

21502155
_, new_labels = algos.safe_sort(uniques, labels, na_sentinel=-1)
2151-
new_labels = ensure_int64(new_labels)
2156+
assert new_labels.dtype == np.intp
21522157
new_left, new_right = new_labels[:llength], new_labels[llength:]
21532158

21542159
return new_left, new_right

pandas/tests/libs/test_join.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,8 @@ def test_outer_join_indexer(self, dtype):
4646
tm.assert_numpy_array_equal(rindexer, exp)
4747

4848
def test_cython_left_outer_join(self):
49-
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.int64)
50-
right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.int64)
49+
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp)
50+
right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp)
5151
max_group = 5
5252

5353
ls, rs = left_outer_join(left, right, max_group)
@@ -70,8 +70,8 @@ def test_cython_left_outer_join(self):
7070
tm.assert_numpy_array_equal(rs, exp_rs, check_dtype=False)
7171

7272
def test_cython_right_outer_join(self):
73-
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.int64)
74-
right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.int64)
73+
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp)
74+
right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp)
7575
max_group = 5
7676

7777
rs, ls = left_outer_join(right, left, max_group)
@@ -116,8 +116,8 @@ def test_cython_right_outer_join(self):
116116
tm.assert_numpy_array_equal(rs, exp_rs, check_dtype=False)
117117

118118
def test_cython_inner_join(self):
119-
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.int64)
120-
right = np.array([1, 1, 0, 4, 2, 2, 1, 4], dtype=np.int64)
119+
left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp)
120+
right = np.array([1, 1, 0, 4, 2, 2, 1, 4], dtype=np.intp)
121121
max_group = 5
122122

123123
ls, rs = inner_join(left, right, max_group)
@@ -256,10 +256,10 @@ def test_left_outer_join_bug():
256256
0,
257257
2,
258258
],
259-
dtype=np.int64,
259+
dtype=np.intp,
260260
)
261261

262-
right = np.array([3, 1], dtype=np.int64)
262+
right = np.array([3, 1], dtype=np.intp)
263263
max_groups = 4
264264

265265
lidx, ridx = libjoin.left_outer_join(left, right, max_groups, sort=False)

0 commit comments

Comments
 (0)