     lib,
 )
 from pandas._libs.hashtable import unique_label_indices
-from pandas._typing import IndexKeyFunc
+from pandas._typing import (
+    IndexKeyFunc,
+    Shape,
+)
 
 from pandas.core.dtypes.common import (
     ensure_int64,
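Context for the new import: as far as I can tell, `Shape` in `pandas._typing` is just a tuple-of-ints alias, so the annotations added below read as plain tuples. A minimal sketch of the assumed alias (not part of this diff):

```python
# Assumed definition, mirroring pandas._typing
from typing import Tuple

Shape = Tuple[int, ...]  # e.g. (3, 4) for two levels with 3 and 4 categories
```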
@@ -93,7 +96,7 @@ def get_indexer_indexer(
     return indexer
 
 
-def get_group_index(labels, shape, sort: bool, xnull: bool):
+def get_group_index(labels, shape: Shape, sort: bool, xnull: bool):
     """
     For the particular label_list, gets the offsets into the hypothetical list
     representing the totally ordered cartesian product of all possible label
@@ -108,7 +111,7 @@ def get_group_index(labels, shape, sort: bool, xnull: bool):
     ----------
     labels : sequence of arrays
         Integers identifying levels at each location
-    shape : sequence of ints
+    shape : tuple[int, ...]
         Number of unique levels at each location
     sort : bool
         If the ranks of returned ids should match lexical ranks of labels
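To make the docstring concrete, here is a small usage sketch of `get_group_index`. The import path is assumed to be `pandas.core.sorting` (the internal module being patched, not public API); with no `-1` labels, `xnull` has no effect:

```python
import numpy as np
from pandas.core.sorting import get_group_index

# Two levels, each with 2 possible label values.
labels = [np.array([0, 1, 1]), np.array([0, 0, 1])]
shape = (2, 2)

# Offsets into the 2 x 2 cartesian product of levels:
# (0, 0) -> 0, (1, 0) -> 2, (1, 1) -> 3
ids = get_group_index(labels, shape, sort=True, xnull=False)
print(ids)  # expected: [0 2 3]
```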
@@ -134,33 +137,36 @@ def _int64_cut_off(shape) -> int:
                 return i
         return len(shape)
 
-    def maybe_lift(lab, size):
+    def maybe_lift(lab, size) -> tuple[np.ndarray, int]:
         # promote nan values (assigned -1 label in lab array)
         # so that all output values are non-negative
         return (lab + 1, size + 1) if (lab == -1).any() else (lab, size)
 
-    labels = map(ensure_int64, labels)
+    labels = [ensure_int64(x) for x in labels]
+    lshape = list(shape)
     if not xnull:
-        labels, shape = map(list, zip(*map(maybe_lift, labels, shape)))
+        for i, (lab, size) in enumerate(zip(labels, shape)):
+            lab, size = maybe_lift(lab, size)
+            labels[i] = lab
+            lshape[i] = size
 
     labels = list(labels)
-    shape = list(shape)
 
     # Iteratively process all the labels in chunks sized so less
     # than _INT64_MAX unique int ids will be required for each chunk
     while True:
         # how many levels can be done without overflow:
-        nlev = _int64_cut_off(shape)
+        nlev = _int64_cut_off(lshape)
 
         # compute flat ids for the first `nlev` levels
-        stride = np.prod(shape[1:nlev], dtype="i8")
+        stride = np.prod(lshape[1:nlev], dtype="i8")
         out = stride * labels[0].astype("i8", subok=False, copy=False)
 
         for i in range(1, nlev):
-            if shape[i] == 0:
-                stride = 0
+            if lshape[i] == 0:
+                stride = np.int64(0)
             else:
-                stride //= shape[i]
+                stride //= lshape[i]
             out += labels[i] * stride
 
         if xnull:  # exclude nulls
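The hunk above is the core flat-id arithmetic that the `lshape` substitution touches. A simplified re-derivation (ignoring the int64 overflow chunking, null masking, and the compression loop; names are illustrative only) for review purposes:

```python
import numpy as np

def flat_ids_sketch(labels, lshape):
    # out = labels[0]*prod(lshape[1:]) + labels[1]*prod(lshape[2:]) + ... + labels[-1]
    stride = np.prod(lshape[1:], dtype="i8")
    out = stride * labels[0].astype("i8")
    for i in range(1, len(lshape)):
        stride = np.int64(0) if lshape[i] == 0 else stride // lshape[i]
        out += labels[i] * stride
    return out

labels = [np.array([0, 1, 2]), np.array([1, 0, 3])]
print(flat_ids_sketch(labels, [3, 4]))  # [ 1  4 11]
```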
@@ -169,20 +175,20 @@ def maybe_lift(lab, size):
                 mask |= lab == -1
             out[mask] = -1
 
-        if nlev == len(shape):  # all levels done!
+        if nlev == len(lshape):  # all levels done!
             break
 
         # compress what has been done so far in order to avoid overflow
         # to retain lexical ranks, obs_ids should be sorted
         comp_ids, obs_ids = compress_group_index(out, sort=sort)
 
         labels = [comp_ids] + labels[nlev:]
-        shape = [len(obs_ids)] + shape[nlev:]
+        lshape = [len(obs_ids)] + lshape[nlev:]
 
     return out
 
 
-def get_compressed_ids(labels, sizes) -> tuple[np.ndarray, np.ndarray]:
+def get_compressed_ids(labels, sizes: Shape) -> tuple[np.ndarray, np.ndarray]:
     """
     Group_index is offsets into cartesian product of all possible labels. This
     space can be huge, so this function compresses it, by computing offsets
@@ -191,7 +197,7 @@ def get_compressed_ids(labels, sizes) -> tuple[np.ndarray, np.ndarray]:
     Parameters
     ----------
     labels : list of label arrays
-    sizes : list of size of the levels
+    sizes : tuple[int] of size of the levels
 
     Returns
     -------
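For reviewers, a hedged usage sketch of `get_compressed_ids` (again assuming the internal `pandas.core.sorting` import path): it returns a dense group number per row plus the observed flat offsets into the cartesian product.

```python
import numpy as np
from pandas.core.sorting import get_compressed_ids

labels = [np.array([0, 1, 1, 0]), np.array([0, 0, 1, 0])]
sizes = (2, 2)

comp_ids, obs_ids = get_compressed_ids(labels, sizes)
# comp_ids: dense group number per row; obs_ids: observed flat offsets
print(comp_ids)  # e.g. [0 1 2 0]
print(obs_ids)   # e.g. [0 2 3]
```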
@@ -252,12 +258,11 @@ def decons_obs_group_ids(comp_ids: np.ndarray, obs_ids, shape, labels, xnull: bo
         return out if xnull or not lift.any() else [x - y for x, y in zip(out, lift)]
 
     # TODO: unique_label_indices only used here, should take ndarray[np.intp]
-    i = unique_label_indices(ensure_int64(comp_ids))
-    i8copy = lambda a: a.astype("i8", subok=False, copy=True)
-    return [i8copy(lab[i]) for lab in labels]
+    indexer = unique_label_indices(ensure_int64(comp_ids))
+    return [lab[indexer].astype(np.intp, subok=False, copy=True) for lab in labels]
 
 
-def indexer_from_factorized(labels, shape, compress: bool = True) -> np.ndarray:
+def indexer_from_factorized(labels, shape: Shape, compress: bool = True) -> np.ndarray:
     # returned ndarray is np.intp
     ids = get_group_index(labels, shape, sort=True, xnull=False)
@@ -334,7 +339,7 @@ def lexsort_indexer(
         shape.append(n)
         labels.append(codes)
 
-    return indexer_from_factorized(labels, shape)
+    return indexer_from_factorized(labels, tuple(shape))
 
 
 def nargsort(
@@ -576,7 +581,7 @@ def get_indexer_dict(
     """
     shape = [len(x) for x in keys]
 
-    group_index = get_group_index(label_list, shape, sort=True, xnull=True)
+    group_index = get_group_index(label_list, tuple(shape), sort=True, xnull=True)
     if np.all(group_index == -1):
         # Short-circuit, lib.indices_fast will return the same
         return {}