Skip to content

Commit 17b530a

Browse files
jbrockmendelproost
authored andcommitted
REF: avoid getattr pattern for diff_2d; use fused types (pandas-dev#29120)
1 parent 1d68fb9 commit 17b530a

File tree

2 files changed

+25
-31
lines changed

2 files changed

+25
-31
lines changed

pandas/_libs/algos_common_helper.pxi.in

+23-22
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,34 @@ Template for each `dtype` helper function using 1-d template
44
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
55
"""
66

7-
{{py:
8-
9-
# name, c_type, dest_type
10-
dtypes = [('float64', 'float64_t', 'float64_t'),
11-
('float32', 'float32_t', 'float32_t'),
12-
('int8', 'int8_t', 'float32_t'),
13-
('int16', 'int16_t', 'float32_t'),
14-
('int32', 'int32_t', 'float64_t'),
15-
('int64', 'int64_t', 'float64_t')]
16-
17-
def get_dispatch(dtypes):
18-
19-
for name, c_type, dest_type, in dtypes:
20-
yield name, c_type, dest_type
7+
ctypedef fused diff_t:
8+
float64_t
9+
float32_t
10+
int8_t
11+
int16_t
12+
int32_t
13+
int64_t
2114

22-
}}
23-
24-
{{for name, c_type, dest_type
25-
in get_dispatch(dtypes)}}
15+
ctypedef fused out_t:
16+
float32_t
17+
float64_t
2618

2719

2820
@cython.boundscheck(False)
2921
@cython.wraparound(False)
30-
def diff_2d_{{name}}(ndarray[{{c_type}}, ndim=2] arr,
31-
ndarray[{{dest_type}}, ndim=2] out,
32-
Py_ssize_t periods, int axis):
22+
def diff_2d(ndarray[diff_t, ndim=2] arr,
23+
ndarray[out_t, ndim=2] out,
24+
Py_ssize_t periods, int axis):
25+
26+
# Disable for unsupported dtype combinations,
27+
# see https://github.com/cython/cython/issues/2646
28+
if out_t is float32_t:
29+
if not (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t):
30+
raise NotImplementedError
31+
else:
32+
if (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t):
33+
raise NotImplementedError
34+
3335
cdef:
3436
Py_ssize_t i, j, sx, sy
3537

@@ -69,7 +71,6 @@ def diff_2d_{{name}}(ndarray[{{c_type}}, ndim=2] arr,
6971
for j in range(start, stop):
7072
out[i, j] = arr[i, j] - arr[i, j - periods]
7173

72-
{{endfor}}
7374

7475
# ----------------------------------------------------------------------
7576
# ensure_dtype

pandas/core/algorithms.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -1850,14 +1850,7 @@ def searchsorted(arr, value, side="left", sorter=None):
18501850
# diff #
18511851
# ---- #
18521852

1853-
_diff_special = {
1854-
"float64": algos.diff_2d_float64,
1855-
"float32": algos.diff_2d_float32,
1856-
"int64": algos.diff_2d_int64,
1857-
"int32": algos.diff_2d_int32,
1858-
"int16": algos.diff_2d_int16,
1859-
"int8": algos.diff_2d_int8,
1860-
}
1853+
_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"}
18611854

18621855

18631856
def diff(arr, n: int, axis: int = 0):
@@ -1905,7 +1898,7 @@ def diff(arr, n: int, axis: int = 0):
19051898
out_arr[tuple(na_indexer)] = na
19061899

19071900
if arr.ndim == 2 and arr.dtype.name in _diff_special:
1908-
f = _diff_special[arr.dtype.name]
1901+
f = algos.diff_2d
19091902
f(arr, out_arr, n, axis)
19101903
else:
19111904
# To keep mypy happy, _res_indexer is a list while res_indexer is

0 commit comments

Comments
 (0)