Skip to content

Commit a5efcce

Browse files
jbrockmendelproost
authored andcommitted
stronger typing in libreduction (pandas-dev#29502)
1 parent dc076fa commit a5efcce

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

pandas/_libs/reduction.pyx

+12-18
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ cdef class Reducer:
4343
"""
4444
cdef:
4545
Py_ssize_t increment, chunksize, nresults
46-
object arr, dummy, f, labels, typ, ityp, index
46+
object dummy, f, labels, typ, ityp, index
47+
ndarray arr
4748

48-
def __init__(self, object arr, object f, axis=1, dummy=None, labels=None):
49-
n, k = arr.shape
49+
def __init__(self, ndarray arr, object f, axis=1, dummy=None, labels=None):
50+
n, k = (<object>arr).shape
5051

5152
if axis == 0:
5253
if not arr.flags.f_contiguous:
@@ -102,7 +103,7 @@ cdef class Reducer:
102103
ndarray arr, result, chunk
103104
Py_ssize_t i, incr
104105
flatiter it
105-
bint has_labels, has_ndarray_labels
106+
bint has_labels
106107
object res, name, labels, index
107108
object cached_typ = None
108109

@@ -112,17 +113,13 @@ cdef class Reducer:
112113
chunk.data = arr.data
113114
labels = self.labels
114115
has_labels = labels is not None
115-
has_ndarray_labels = util.is_array(labels)
116116
has_index = self.index is not None
117117
incr = self.increment
118118

119119
try:
120120
for i in range(self.nresults):
121121

122-
if has_ndarray_labels:
123-
name = labels[i]
124-
elif has_labels:
125-
# labels is an ExtensionArray
122+
if has_labels:
126123
name = labels[i]
127124
else:
128125
name = None
@@ -206,11 +203,10 @@ cdef class SeriesBinGrouper(_BaseGrouper):
206203
Py_ssize_t nresults, ngroups
207204

208205
cdef public:
209-
object arr, index, dummy_arr, dummy_index
206+
ndarray arr, index, dummy_arr, dummy_index
210207
object values, f, bins, typ, ityp, name
211208

212209
def __init__(self, object series, object f, object bins, object dummy):
213-
n = len(series)
214210

215211
assert dummy is not None # always obj[:0]
216212

@@ -317,12 +313,11 @@ cdef class SeriesGrouper(_BaseGrouper):
317313
Py_ssize_t nresults, ngroups
318314

319315
cdef public:
320-
object arr, index, dummy_arr, dummy_index
316+
ndarray arr, index, dummy_arr, dummy_index
321317
object f, labels, values, typ, ityp, name
322318

323319
def __init__(self, object series, object f, object labels,
324320
Py_ssize_t ngroups, object dummy):
325-
n = len(series)
326321

327322
# in practice we always pass either obj[:0] or the
328323
# safer obj._get_values(slice(None, 0))
@@ -446,14 +441,13 @@ cdef class Slider:
446441
Py_ssize_t stride, orig_len, orig_stride
447442
char *orig_data
448443

449-
def __init__(self, object values, object buf):
450-
assert (values.ndim == 1)
444+
def __init__(self, ndarray values, ndarray buf):
445+
assert values.ndim == 1
446+
assert values.dtype == buf.dtype
451447

452-
if util.is_array(values) and not values.flags.contiguous:
453-
# e.g. Categorical has no `flags` attribute
448+
if not values.flags.contiguous:
454449
values = values.copy()
455450

456-
assert (values.dtype == buf.dtype)
457451
self.values = values
458452
self.buf = buf
459453
self.stride = values.strides[0]

0 commit comments

Comments
 (0)