Skip to content

Commit 601c43c

Browse files
authored
TYP: core.sorting (pandas-dev#41285)
1 parent 30dd641 commit 601c43c

File tree

4 files changed

+40
-28
lines changed

4 files changed

+40
-28
lines changed

pandas/core/frame.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6146,7 +6146,7 @@ def duplicated(
61466146
if self.empty:
61476147
return self._constructor_sliced(dtype=bool)
61486148

6149-
def f(vals):
6149+
def f(vals) -> tuple[np.ndarray, int]:
61506150
labels, shape = algorithms.factorize(vals, size_hint=len(self))
61516151
return labels.astype("i8", copy=False), len(shape)
61526152

@@ -6173,7 +6173,14 @@ def f(vals):
61736173
vals = (col.values for name, col in self.items() if name in subset)
61746174
labels, shape = map(list, zip(*map(f, vals)))
61756175

6176-
ids = get_group_index(labels, shape, sort=False, xnull=False)
6176+
ids = get_group_index(
6177+
labels,
6178+
# error: Argument 1 to "tuple" has incompatible type "List[_T]";
6179+
# expected "Iterable[int]"
6180+
tuple(shape), # type: ignore[arg-type]
6181+
sort=False,
6182+
xnull=False,
6183+
)
61776184
result = self._constructor_sliced(duplicated_int64(ids, keep), index=self.index)
61786185
return result.__finalize__(self, method="duplicated")
61796186

pandas/core/indexes/multi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1611,7 +1611,7 @@ def _inferred_type_levels(self) -> list[str]:
16111611

16121612
@doc(Index.duplicated)
16131613
def duplicated(self, keep="first") -> np.ndarray:
1614-
shape = map(len, self.levels)
1614+
shape = tuple(len(lev) for lev in self.levels)
16151615
ids = get_group_index(self.codes, shape, sort=False, xnull=False)
16161616

16171617
return duplicated_int64(ids, keep)

pandas/core/reshape/reshape.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _indexer_and_to_sort(
142142
codes = list(self.index.codes)
143143
levs = list(self.index.levels)
144144
to_sort = codes[:v] + codes[v + 1 :] + [codes[v]]
145-
sizes = [len(x) for x in levs[:v] + levs[v + 1 :] + [levs[v]]]
145+
sizes = tuple(len(x) for x in levs[:v] + levs[v + 1 :] + [levs[v]])
146146

147147
comp_index, obs_ids = get_compressed_ids(to_sort, sizes)
148148
ngroups = len(obs_ids)
@@ -166,7 +166,7 @@ def _make_selectors(self):
166166

167167
# make the mask
168168
remaining_labels = self.sorted_labels[:-1]
169-
level_sizes = [len(x) for x in new_levels]
169+
level_sizes = tuple(len(x) for x in new_levels)
170170

171171
comp_index, obs_ids = get_compressed_ids(remaining_labels, level_sizes)
172172
ngroups = len(obs_ids)
@@ -353,7 +353,7 @@ def _unstack_multiple(data, clocs, fill_value=None):
353353
rcodes = [index.codes[i] for i in rlocs]
354354
rnames = [index.names[i] for i in rlocs]
355355

356-
shape = [len(x) for x in clevels]
356+
shape = tuple(len(x) for x in clevels)
357357
group_index = get_group_index(ccodes, shape, sort=False, xnull=False)
358358

359359
comp_ids, obs_ids = compress_group_index(group_index, sort=False)

pandas/core/sorting.py

+27-22
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
lib,
1919
)
2020
from pandas._libs.hashtable import unique_label_indices
21-
from pandas._typing import IndexKeyFunc
21+
from pandas._typing import (
22+
IndexKeyFunc,
23+
Shape,
24+
)
2225

2326
from pandas.core.dtypes.common import (
2427
ensure_int64,
@@ -93,7 +96,7 @@ def get_indexer_indexer(
9396
return indexer
9497

9598

96-
def get_group_index(labels, shape, sort: bool, xnull: bool):
99+
def get_group_index(labels, shape: Shape, sort: bool, xnull: bool):
97100
"""
98101
For the particular label_list, gets the offsets into the hypothetical list
99102
representing the totally ordered cartesian product of all possible label
@@ -108,7 +111,7 @@ def get_group_index(labels, shape, sort: bool, xnull: bool):
108111
----------
109112
labels : sequence of arrays
110113
Integers identifying levels at each location
111-
shape : sequence of ints
114+
shape : tuple[int, ...]
112115
Number of unique levels at each location
113116
sort : bool
114117
If the ranks of returned ids should match lexical ranks of labels
@@ -134,33 +137,36 @@ def _int64_cut_off(shape) -> int:
134137
return i
135138
return len(shape)
136139

137-
def maybe_lift(lab, size):
140+
def maybe_lift(lab, size) -> tuple[np.ndarray, int]:
138141
# promote nan values (assigned -1 label in lab array)
139142
# so that all output values are non-negative
140143
return (lab + 1, size + 1) if (lab == -1).any() else (lab, size)
141144

142-
labels = map(ensure_int64, labels)
145+
labels = [ensure_int64(x) for x in labels]
146+
lshape = list(shape)
143147
if not xnull:
144-
labels, shape = map(list, zip(*map(maybe_lift, labels, shape)))
148+
for i, (lab, size) in enumerate(zip(labels, shape)):
149+
lab, size = maybe_lift(lab, size)
150+
labels[i] = lab
151+
lshape[i] = size
145152

146153
labels = list(labels)
147-
shape = list(shape)
148154

149155
# Iteratively process all the labels in chunks sized so less
150156
# than _INT64_MAX unique int ids will be required for each chunk
151157
while True:
152158
# how many levels can be done without overflow:
153-
nlev = _int64_cut_off(shape)
159+
nlev = _int64_cut_off(lshape)
154160

155161
# compute flat ids for the first `nlev` levels
156-
stride = np.prod(shape[1:nlev], dtype="i8")
162+
stride = np.prod(lshape[1:nlev], dtype="i8")
157163
out = stride * labels[0].astype("i8", subok=False, copy=False)
158164

159165
for i in range(1, nlev):
160-
if shape[i] == 0:
161-
stride = 0
166+
if lshape[i] == 0:
167+
stride = np.int64(0)
162168
else:
163-
stride //= shape[i]
169+
stride //= lshape[i]
164170
out += labels[i] * stride
165171

166172
if xnull: # exclude nulls
@@ -169,20 +175,20 @@ def maybe_lift(lab, size):
169175
mask |= lab == -1
170176
out[mask] = -1
171177

172-
if nlev == len(shape): # all levels done!
178+
if nlev == len(lshape): # all levels done!
173179
break
174180

175181
# compress what has been done so far in order to avoid overflow
176182
# to retain lexical ranks, obs_ids should be sorted
177183
comp_ids, obs_ids = compress_group_index(out, sort=sort)
178184

179185
labels = [comp_ids] + labels[nlev:]
180-
shape = [len(obs_ids)] + shape[nlev:]
186+
lshape = [len(obs_ids)] + lshape[nlev:]
181187

182188
return out
183189

184190

185-
def get_compressed_ids(labels, sizes) -> tuple[np.ndarray, np.ndarray]:
191+
def get_compressed_ids(labels, sizes: Shape) -> tuple[np.ndarray, np.ndarray]:
186192
"""
187193
Group_index is offsets into cartesian product of all possible labels. This
188194
space can be huge, so this function compresses it, by computing offsets
@@ -191,7 +197,7 @@ def get_compressed_ids(labels, sizes) -> tuple[np.ndarray, np.ndarray]:
191197
Parameters
192198
----------
193199
labels : list of label arrays
194-
sizes : list of size of the levels
200+
sizes : tuple[int, ...] of the sizes of the levels
195201
196202
Returns
197203
-------
@@ -252,12 +258,11 @@ def decons_obs_group_ids(comp_ids: np.ndarray, obs_ids, shape, labels, xnull: bo
252258
return out if xnull or not lift.any() else [x - y for x, y in zip(out, lift)]
253259

254260
# TODO: unique_label_indices only used here, should take ndarray[np.intp]
255-
i = unique_label_indices(ensure_int64(comp_ids))
256-
i8copy = lambda a: a.astype("i8", subok=False, copy=True)
257-
return [i8copy(lab[i]) for lab in labels]
261+
indexer = unique_label_indices(ensure_int64(comp_ids))
262+
return [lab[indexer].astype(np.intp, subok=False, copy=True) for lab in labels]
258263

259264

260-
def indexer_from_factorized(labels, shape, compress: bool = True) -> np.ndarray:
265+
def indexer_from_factorized(labels, shape: Shape, compress: bool = True) -> np.ndarray:
261266
# returned ndarray is np.intp
262267
ids = get_group_index(labels, shape, sort=True, xnull=False)
263268

@@ -334,7 +339,7 @@ def lexsort_indexer(
334339
shape.append(n)
335340
labels.append(codes)
336341

337-
return indexer_from_factorized(labels, shape)
342+
return indexer_from_factorized(labels, tuple(shape))
338343

339344

340345
def nargsort(
@@ -576,7 +581,7 @@ def get_indexer_dict(
576581
"""
577582
shape = [len(x) for x in keys]
578583

579-
group_index = get_group_index(label_list, shape, sort=True, xnull=True)
584+
group_index = get_group_index(label_list, tuple(shape), sort=True, xnull=True)
580585
if np.all(group_index == -1):
581586
# Short-circuit, lib.indices_fast will return the same
582587
return {}

0 commit comments

Comments
 (0)