@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423
423
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424
424
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425
425
426
+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426
427
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427
428
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428
429
// gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435
436
436
437
if (NA0 * NA1)
437
438
{
438
- // If A is neither C- nor F-contiguous, we make a copy.
439
- // TODO:
440
- // - if one stride is equal to "- elemsize", we can still call
441
- // gemv on reversed matrix and vectors
442
- // - if the copy is too long, maybe call vector/vector dot on
443
- // each row instead
444
- if ((PyArray_STRIDES(%(A)s)[0] < 0)
445
- || (PyArray_STRIDES(%(A)s)[1] < 0)
446
- || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447
- && (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439
+ if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
448
440
{
441
+ // We can treat the array A as C-or F-contiguous by changing the order of iteration
442
+ if (SA0 < 0){
443
+ A_data += (NA0 -1) * SA0; // Jump to first row
444
+ SA0 = -SA0; // Iterate over rows in reverse
445
+ Sz = -Sz; // Iterate over y in reverse
446
+ }
447
+ if (SA1 < 0){
448
+ A_data += (NA1 -1) * SA1; // Jump to first column
449
+ SA1 = -SA1; // Iterate over columns in reverse
450
+ Sx = -Sx; // Iterate over x in reverse
451
+ }
452
+ } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453
+ {
454
+ // Array isn't contiguous, we have to make a copy
455
+ // - if the copy is too long, maybe call vector/vector dot on
456
+ // each row instead
457
+ // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\ n", SA0, SA1);
449
458
npy_intp dims[2];
450
459
dims[0] = NA0;
451
460
dims[1] = NA1;
@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458
467
%(A)s = A_copy;
459
468
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460
469
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
470
+ A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461
471
}
462
472
463
- if (PyArray_STRIDES(%(A)s)[0] == elemsize )
473
+ if (SA0 == 1 )
464
474
{
465
475
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466
476
{
467
477
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468
478
sgemv_(&NOTRANS, &NA0, &NA1,
469
479
&alpha,
470
- (float*)(PyArray_DATA(%(A)s) ), &SA1,
480
+ (float*)(A_data ), &SA1,
471
481
(float*)x_data, &Sx,
472
482
&fbeta,
473
483
(float*)z_data, &Sz);
@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477
487
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478
488
dgemv_(&NOTRANS, &NA0, &NA1,
479
489
&alpha,
480
- (double*)(PyArray_DATA(%(A)s) ), &SA1,
490
+ (double*)(A_data ), &SA1,
481
491
(double*)x_data, &Sx,
482
492
&dbeta,
483
493
(double*)z_data, &Sz);
@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489
499
%(fail)s
490
500
}
491
501
}
492
- else if (PyArray_STRIDES(%(A)s)[1] == elemsize )
502
+ else if (SA1 == 1 )
493
503
{
494
504
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495
505
{
@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506
516
z_data[0] = 0.f;
507
517
}
508
518
z_data[0] += alpha*sdot_(&NA1,
509
- (float*)(PyArray_DATA(%(A)s) ), &SA1,
519
+ (float*)(A_data ), &SA1,
510
520
(float*)x_data, &Sx);
511
521
}
512
522
else
513
523
{
514
524
sgemv_(&TRANS, &NA1, &NA0,
515
525
&alpha,
516
- (float*)(PyArray_DATA(%(A)s) ), &SA0,
526
+ (float*)(A_data ), &SA0,
517
527
(float*)x_data, &Sx,
518
528
&fbeta,
519
529
(float*)z_data, &Sz);
@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534
544
z_data[0] = 0.;
535
545
}
536
546
z_data[0] += alpha*ddot_(&NA1,
537
- (double*)(PyArray_DATA(%(A)s) ), &SA1,
547
+ (double*)(A_data ), &SA1,
538
548
(double*)x_data, &Sx);
539
549
}
540
550
else
541
551
{
542
552
dgemv_(&TRANS, &NA1, &NA0,
543
553
&alpha,
544
- (double*)(PyArray_DATA(%(A)s) ), &SA0,
554
+ (double*)(A_data ), &SA0,
545
555
(double*)x_data, &Sx,
546
556
&dbeta,
547
557
(double*)z_data, &Sz);
@@ -603,7 +613,7 @@ def c_code(self, node, name, inp, out, sub):
603
613
return code
604
614
605
615
def c_code_cache_version (self ):
606
- return (14 , blas_header_version (), check_force_gemv_init ())
616
+ return (15 , blas_header_version (), check_force_gemv_init ())
607
617
608
618
609
619
cgemv_inplace = CGemv (inplace = True )
0 commit comments