Skip to content

Commit 2f182a9

Browse files
committed
Support parallel calculation of nancorr
1 parent e2bf1ff commit 2f182a9

File tree

3 files changed

+84
-40
lines changed

3 files changed

+84
-40
lines changed

pandas/_libs/algos.pyx

+75-34
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
import cython
44
from cython import Py_ssize_t
5+
from cython.parallel import prange
56

67
from libc.stdlib cimport malloc, free
78
from libc.string cimport memmove
89
from libc.math cimport fabs, sqrt
10+
from cpython cimport bool
911

1012
import numpy as np
1113
cimport numpy as cnp
@@ -230,14 +232,15 @@ def kth_smallest(numeric[:] a, Py_ssize_t k) -> numeric:
230232

231233
@cython.boundscheck(False)
232234
@cython.wraparound(False)
233-
def nancorr(ndarray[float64_t, ndim=2] mat, bint cov=0, minp=None):
235+
def nancorr(float64_t[:, :] mat, bint cov=0, minp=None, bool parallel=False):
234236
cdef:
235237
Py_ssize_t i, j, xi, yi, N, K
236238
bint minpv
237-
ndarray[float64_t, ndim=2] result
238-
ndarray[uint8_t, ndim=2] mask
239+
float64_t[:, :] result
240+
uint8_t[:, :] mask
239241
int64_t nobs = 0
240242
float64_t vx, vy, sumx, sumy, sumxx, sumyy, meanx, meany, divisor
243+
int64_t blah = 0
241244

242245
N, K = (<object>mat).shape
243246

@@ -249,44 +252,82 @@ def nancorr(ndarray[float64_t, ndim=2] mat, bint cov=0, minp=None):
249252
result = np.empty((K, K), dtype=np.float64)
250253
mask = np.isfinite(mat).view(np.uint8)
251254

252-
with nogil:
253-
for xi in range(K):
254-
for yi in range(xi + 1):
255-
nobs = sumxx = sumyy = sumx = sumy = 0
256-
for i in range(N):
257-
if mask[i, xi] and mask[i, yi]:
258-
vx = mat[i, xi]
259-
vy = mat[i, yi]
260-
nobs += 1
261-
sumx += vx
262-
sumy += vy
255+
if parallel:
256+
with nogil:
257+
for xi in prange(K, schedule='dynamic'):
258+
nancorr_single_row(mat, N, K, result, xi, mask, minpv, cov)
259+
else:
260+
with nogil:
261+
for xi in range(K):
262+
nancorr_single_row(mat, N, K, result, xi, mask, minpv, cov)
263263

264-
if nobs < minpv:
265-
result[xi, yi] = result[yi, xi] = NaN
266-
else:
267-
meanx = sumx / nobs
268-
meany = sumy / nobs
264+
return np.asarray(result)
269265

270-
# now the cov numerator
271-
sumx = 0
272266

273-
for i in range(N):
274-
if mask[i, xi] and mask[i, yi]:
275-
vx = mat[i, xi] - meanx
276-
vy = mat[i, yi] - meany
267+
@cython.boundscheck(False)
268+
@cython.wraparound(False)
269+
cdef void nancorr_single_row(float64_t[:, :] mat,
270+
Py_ssize_t N,
271+
Py_ssize_t K,
272+
float64_t[:, :] result,
273+
Py_ssize_t xi,
274+
uint8_t[:, :] mask,
275+
bint minpv,
276+
bint cov=0) nogil:
277+
for yi in range(xi + 1):
278+
nancorr_single(mat, N, K, result, xi, yi, mask, minpv, cov)
277279

278-
sumx += vx * vy
279-
sumxx += vx * vx
280-
sumyy += vy * vy
281280

282-
divisor = (nobs - 1.0) if cov else sqrt(sumxx * sumyy)
281+
@cython.boundscheck(False)
282+
@cython.wraparound(False)
283+
cdef void nancorr_single(float64_t[:, :] mat,
284+
Py_ssize_t N,
285+
Py_ssize_t K,
286+
float64_t[:, :] result,
287+
Py_ssize_t xi,
288+
Py_ssize_t yi,
289+
uint8_t[:, :] mask,
290+
bint minpv,
291+
bint cov=0) nogil:
292+
cdef:
293+
Py_ssize_t i, j
294+
int64_t nobs = 0
295+
float64_t vx, vy, sumx, sumy, sumxx, sumyy, meanx, meany, divisor
283296

284-
if divisor != 0:
285-
result[xi, yi] = result[yi, xi] = sumx / divisor
286-
else:
287-
result[xi, yi] = result[yi, xi] = NaN
297+
nobs = sumxx = sumyy = sumx = sumy = 0
298+
for i in range(N):
299+
if mask[i, xi] and mask[i, yi]:
300+
vx = mat[i, xi]
301+
vy = mat[i, yi]
302+
nobs += 1
303+
sumx += vx
304+
sumy += vy
305+
306+
if nobs < minpv:
307+
result[xi, yi] = result[yi, xi] = NaN
308+
else:
309+
meanx = sumx / nobs
310+
meany = sumy / nobs
311+
312+
# now the cov numerator
313+
sumx = 0
314+
315+
for i in range(N):
316+
if mask[i, xi] and mask[i, yi]:
317+
vx = mat[i, xi] - meanx
318+
vy = mat[i, yi] - meany
319+
320+
sumx += vx * vy
321+
sumxx += vx * vx
322+
sumyy += vy * vy
323+
324+
divisor = (nobs - 1.0) if cov else sqrt(sumxx * sumyy)
325+
326+
if divisor != 0:
327+
result[xi, yi] = result[yi, xi] = sumx / divisor
328+
else:
329+
result[xi, yi] = result[yi, xi] = NaN
288330

289-
return result
290331

291332
# ----------------------------------------------------------------------
292333
# Pairwise Spearman correlation

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6996,7 +6996,7 @@ def corr(self, method='pearson', min_periods=1):
69966996
mat = numeric_df.values
69976997

69986998
if method == 'pearson':
6999-
correl = libalgos.nancorr(ensure_float64(mat), minp=min_periods)
6999+
correl = libalgos.nancorr(ensure_float64(mat), minp=min_periods, parallel=True)
70007000
elif method == 'spearman':
70017001
correl = libalgos.nancorr_spearman(ensure_float64(mat),
70027002
minp=min_periods)

setup.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
from os.path import join as pjoin
1111

12+
import numpy
1213
import pkg_resources
1314
import platform
1415
from distutils.sysconfig import get_config_var
@@ -677,10 +678,11 @@ def srcpath(name=None, suffix='.pyx', subdir='src'):
677678
obj = Extension('pandas.{name}'.format(name=name),
678679
sources=sources,
679680
depends=data.get('depends', []),
680-
include_dirs=include,
681+
include_dirs=include + [numpy.get_include()],
681682
language=data.get('language', 'c'),
682683
define_macros=data.get('macros', macros),
683-
extra_compile_args=extra_compile_args)
684+
extra_compile_args=['-fopenmp'] + extra_compile_args,
685+
extra_link_args=['-fopenmp'])
684686

685687
extensions.append(obj)
686688

@@ -704,12 +706,13 @@ def srcpath(name=None, suffix='.pyx', subdir='src'):
704706
np_datetime_sources),
705707
include_dirs=['pandas/_libs/src/ujson/python',
706708
'pandas/_libs/src/ujson/lib',
707-
'pandas/_libs/src/datetime'],
708-
extra_compile_args=(['-D_GNU_SOURCE'] +
709+
'pandas/_libs/src/datetime',
710+
numpy.get_include()],
711+
extra_compile_args=(['-D_GNU_SOURCE', '-fopenmp'] +
709712
extra_compile_args),
713+
extra_link_args=['-fopenmp'],
710714
define_macros=macros)
711715

712-
713716
extensions.append(ujson_ext)
714717

715718
# ----------------------------------------------------------------------

0 commit comments

Comments
 (0)