Skip to content

Commit afe934b

Browse files
committed
Avoid copy of flipped A matrices in GEMV
1 parent b2365e0 commit afe934b

File tree

2 files changed

+71
-19
lines changed

2 files changed

+71
-19
lines changed

pytensor/tensor/blas_c.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423423
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424424
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425425
426+
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426427
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427428
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428429
// gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435436
436437
if (NA0 * NA1)
437438
{
438-
// If A is neither C- nor F-contiguous, we make a copy.
439-
// TODO:
440-
// - if one stride is equal to "- elemsize", we can still call
441-
// gemv on reversed matrix and vectors
442-
// - if the copy is too long, maybe call vector/vector dot on
443-
// each row instead
444-
if ((PyArray_STRIDES(%(A)s)[0] < 0)
445-
|| (PyArray_STRIDES(%(A)s)[1] < 0)
446-
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447-
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439+
if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
448440
{
441+
// We can treat the array A as C-or F-contiguous by changing the order of iteration
442+
if (SA0 < 0){
443+
A_data += (NA0 -1) * SA0; // Jump to first row
444+
SA0 = -SA0; // Iterate over rows in reverse
445+
Sz = -Sz; // Iterate over y in reverse
446+
}
447+
if (SA1 < 0){
448+
A_data += (NA1 -1) * SA1; // Jump to first column
449+
SA1 = -SA1; // Iterate over columns in reverse
450+
Sx = -Sx; // Iterate over x in reverse
451+
}
452+
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453+
{
454+
// Array isn't contiguous, we have to make a copy
455+
// - if the copy is too long, maybe call vector/vector dot on
456+
// each row instead
457+
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
449458
npy_intp dims[2];
450459
dims[0] = NA0;
451460
dims[1] = NA1;
@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458467
%(A)s = A_copy;
459468
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460469
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
470+
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461471
}
462472
463-
if (PyArray_STRIDES(%(A)s)[0] == elemsize)
473+
if (SA0 == 1)
464474
{
465475
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466476
{
467477
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468478
sgemv_(&NOTRANS, &NA0, &NA1,
469479
&alpha,
470-
(float*)(PyArray_DATA(%(A)s)), &SA1,
480+
(float*)(A_data), &SA1,
471481
(float*)x_data, &Sx,
472482
&fbeta,
473483
(float*)z_data, &Sz);
@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477487
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478488
dgemv_(&NOTRANS, &NA0, &NA1,
479489
&alpha,
480-
(double*)(PyArray_DATA(%(A)s)), &SA1,
490+
(double*)(A_data), &SA1,
481491
(double*)x_data, &Sx,
482492
&dbeta,
483493
(double*)z_data, &Sz);
@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489499
%(fail)s
490500
}
491501
}
492-
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
502+
else if (SA1 == 1)
493503
{
494504
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495505
{
@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506516
z_data[0] = 0.f;
507517
}
508518
z_data[0] += alpha*sdot_(&NA1,
509-
(float*)(PyArray_DATA(%(A)s)), &SA1,
519+
(float*)(A_data), &SA1,
510520
(float*)x_data, &Sx);
511521
}
512522
else
513523
{
514524
sgemv_(&TRANS, &NA1, &NA0,
515525
&alpha,
516-
(float*)(PyArray_DATA(%(A)s)), &SA0,
526+
(float*)(A_data), &SA0,
517527
(float*)x_data, &Sx,
518528
&fbeta,
519529
(float*)z_data, &Sz);
@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534544
z_data[0] = 0.;
535545
}
536546
z_data[0] += alpha*ddot_(&NA1,
537-
(double*)(PyArray_DATA(%(A)s)), &SA1,
547+
(double*)(A_data), &SA1,
538548
(double*)x_data, &Sx);
539549
}
540550
else
541551
{
542552
dgemv_(&TRANS, &NA1, &NA0,
543553
&alpha,
544-
(double*)(PyArray_DATA(%(A)s)), &SA0,
554+
(double*)(A_data), &SA0,
545555
(double*)x_data, &Sx,
546556
&dbeta,
547557
(double*)z_data, &Sz);
@@ -603,7 +613,7 @@ def c_code(self, node, name, inp, out, sub):
603613
return code
604614

605615
def c_code_cache_version(self):
606-
return (14, blas_header_version(), check_force_gemv_init())
616+
return (15, blas_header_version(), check_force_gemv_init())
607617

608618

609619
cgemv_inplace = CGemv(inplace=True)

tests/tensor/test_blas_c.py

+42
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags):
411411

412412
class TestBlasStridesC(TestBlasStrides):
413413
mode = mode_blas_opt
414+
415+
416+
@pytest.mark.parametrize(
417+
"neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
418+
)
419+
@pytest.mark.parametrize(
420+
"neg_stride0", (True, False), ids=["neg_stride0", "pos_stride0"]
421+
)
422+
@pytest.mark.parametrize("F_layout", (True, False), ids=["F_layout", "C_layout"])
423+
def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmark):
424+
A = pt.matrix("A", shape=(512, 512))
425+
x = pt.vector("x", shape=(A.type.shape[-1],))
426+
y = pt.vector("y", shape=(A.type.shape[0],))
427+
428+
out = CGemv(inplace=False)(
429+
y,
430+
1.0,
431+
A,
432+
x,
433+
1.0,
434+
)
435+
fn = pytensor.function([A, x, y], out, trust_input=True)
436+
437+
rng = np.random.default_rng(430)
438+
test_A = rng.normal(size=A.type.shape)
439+
test_x = rng.normal(size=x.type.shape)
440+
test_y = rng.normal(size=y.type.shape)
441+
442+
if F_layout:
443+
test_A = test_A.T
444+
if neg_stride0:
445+
test_A = test_A[::-1]
446+
if neg_stride1:
447+
test_A = test_A[:, ::-1]
448+
assert (test_A.strides[0] < 0) == neg_stride0
449+
assert (test_A.strides[1] < 0) == neg_stride1
450+
451+
# Check result is correct by using a copy of A with positive strides
452+
res = fn(test_A, test_x, test_y)
453+
np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y))
454+
455+
benchmark(fn, test_A, test_x, test_y)

0 commit comments

Comments
 (0)