@@ -224,7 +224,9 @@ def is_int64_overflow_possible(shape: Shape) -> bool:
224
224
return the_prod >= lib .i8max
225
225
226
226
227
- def decons_group_index (comp_labels , shape : Shape ):
227
+ def _decons_group_index (
228
+ comp_labels : npt .NDArray [np .intp ], shape : Shape
229
+ ) -> list [npt .NDArray [np .intp ]]:
228
230
# reconstruct labels
229
231
if is_int64_overflow_possible (shape ):
230
232
# at some point group indices are factorized,
@@ -233,7 +235,7 @@ def decons_group_index(comp_labels, shape: Shape):
233
235
234
236
label_list = []
235
237
factor = 1
236
- y = 0
238
+ y = np . array ( 0 )
237
239
x = comp_labels
238
240
for i in reversed (range (len (shape ))):
239
241
labels = (x - y ) % (factor * shape [i ]) // factor
@@ -245,24 +247,32 @@ def decons_group_index(comp_labels, shape: Shape):
245
247
246
248
247
249
def decons_obs_group_ids (
248
- comp_ids : npt .NDArray [np .intp ], obs_ids , shape : Shape , labels , xnull : bool
249
- ):
250
+ comp_ids : npt .NDArray [np .intp ],
251
+ obs_ids : npt .NDArray [np .intp ],
252
+ shape : Shape ,
253
+ labels : Sequence [npt .NDArray [np .signedinteger ]],
254
+ xnull : bool ,
255
+ ) -> list [npt .NDArray [np .intp ]]:
250
256
"""
251
257
Reconstruct labels from observed group ids.
252
258
253
259
Parameters
254
260
----------
255
261
comp_ids : np.ndarray[np.intp]
262
+ obs_ids: np.ndarray[np.intp]
263
+ shape : tuple[int]
264
+ labels : Sequence[np.ndarray[np.signedinteger]]
256
265
xnull : bool
257
266
If nulls are excluded; i.e. -1 labels are passed through.
258
267
"""
259
268
if not xnull :
260
- lift = np .fromiter (((a == - 1 ).any () for a in labels ), dtype = "i8" )
261
- shape = np .asarray (shape , dtype = "i8" ) + lift
269
+ lift = np .fromiter (((a == - 1 ).any () for a in labels ), dtype = np .intp )
270
+ arr_shape = np .asarray (shape , dtype = np .intp ) + lift
271
+ shape = tuple (arr_shape )
262
272
263
273
if not is_int64_overflow_possible (shape ):
264
274
# obs ids are deconstructable! take the fast route!
265
- out = decons_group_index (obs_ids , shape )
275
+ out = _decons_group_index (obs_ids , shape )
266
276
return out if xnull or not lift .any () else [x - y for x , y in zip (out , lift )]
267
277
268
278
indexer = unique_label_indices (comp_ids )
0 commit comments