Skip to content

Commit c67efb7

Browse files
committed
Feat: finish level 1 ops for cublas
1 parent c1ed73f commit c67efb7

File tree

3 files changed

+275
-1
lines changed

3 files changed

+275
-1
lines changed

crates/blastoff/src/context.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ bitflags::bitflags! {
5858
/// - [Conjugated Complex Dot Product <span style="float:right;">`dotc`</span>](CublasContext::dotc)
5959
/// - [Euclidian Norm <span style="float:right;">`nrm2`</span>](CublasContext::nrm2)
6060
/// - [Rotate points in the xy-plane using a Givens rotation matrix <span style="float:right;">`rot`</span>](CublasContext::rot)
61+
/// - [Construct the givens rotation matrix that zeros the second entry of a vector<span style="float:right;">`rotg`</span>](CublasContext::rotg)
62+
/// - [Apply the modified Givens transformation to vectors <span style="float:right;">`rotm`</span>](CublasContext::rotm)
63+
/// - [Construct the modified givens rotation matrix that zeros the second entry of a vector<span style="float:right;">`rotmg`</span>](CublasContext::rotmg)
64+
/// - [Scale a vector by a scalar <span style="float:right;">`scal`</span>](CublasContext::scal)
65+
/// - [Swap two vectors <span style="float:right;">`swap`</span>](CublasContext::swap)
6166
#[derive(Debug)]
6267
pub struct CublasContext {
6368
pub(crate) raw: sys::v2::cublasHandle_t,

crates/blastoff/src/level1.rs

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
context::CublasContext,
55
error::{Error, ToResult},
66
raw::{ComplexLevel1, FloatLevel1, Level1},
7-
BlasDatatype,
7+
BlasDatatype, Float,
88
};
99
use cust::memory::{GpuBox, GpuBuffer};
1010
use cust::stream::Stream;
@@ -641,4 +641,172 @@ impl CublasContext {
641641
) -> Result {
642642
self.rot_strided(stream, n, x, None, y, None, c, s)
643643
}
644+
645+
/// Constructs the givens rotation matrix that zeros out the second entry of a 2x1 vector.
646+
pub fn rotg<T: Level1>(
647+
&mut self,
648+
stream: &Stream,
649+
a: &mut impl GpuBox<T>,
650+
b: &mut impl GpuBox<T>,
651+
c: &mut impl GpuBox<T::FloatTy>,
652+
s: &mut impl GpuBox<T>,
653+
) -> Result {
654+
self.with_stream(stream, |ctx| unsafe {
655+
Ok(T::rotg(
656+
ctx.raw,
657+
a.as_device_ptr().as_mut_ptr(),
658+
b.as_device_ptr().as_mut_ptr(),
659+
c.as_device_ptr().as_mut_ptr(),
660+
s.as_device_ptr().as_mut_ptr(),
661+
)
662+
.to_result()?)
663+
})
664+
}
665+
666+
/// Same as [`CublasContext::rotm`] but with an explicit stride.
667+
pub fn rotm_strided<T: Level1 + Float>(
668+
&mut self,
669+
stream: &Stream,
670+
n: usize,
671+
x: &mut impl GpuBuffer<T>,
672+
x_stride: Option<usize>,
673+
y: &mut impl GpuBuffer<T>,
674+
y_stride: Option<usize>,
675+
param: &impl GpuBox<T::FloatTy>,
676+
) -> Result {
677+
check_stride(x, n, x_stride);
678+
check_stride(y, n, y_stride);
679+
680+
self.with_stream(stream, |ctx| unsafe {
681+
Ok(T::rotm(
682+
ctx.raw,
683+
n as i32,
684+
x.as_device_ptr().as_mut_ptr(),
685+
x_stride.unwrap_or(1) as i32,
686+
y.as_device_ptr().as_mut_ptr(),
687+
y_stride.unwrap_or(1) as i32,
688+
param.as_device_ptr().as_ptr(),
689+
)
690+
.to_result()?)
691+
})
692+
}
693+
694+
/// Applies the modified givens transformation to vectors `x` and `y`.
695+
pub fn rotm<T: Level1 + Float>(
696+
&mut self,
697+
stream: &Stream,
698+
n: usize,
699+
x: &mut impl GpuBuffer<T>,
700+
y: &mut impl GpuBuffer<T>,
701+
param: &impl GpuBox<T::FloatTy>,
702+
) -> Result {
703+
self.rotm_strided(stream, n, x, None, y, None, param)
704+
}
705+
706+
/// Same as [`CublasContext::rotmg`] but with an explicit stride.
707+
pub fn rotmg_strided<T: Level1 + Float>(
708+
&mut self,
709+
stream: &Stream,
710+
d1: &mut impl GpuBox<T>,
711+
d2: &mut impl GpuBox<T>,
712+
x1: &mut impl GpuBox<T>,
713+
y1: &mut impl GpuBox<T>,
714+
param: &mut impl GpuBox<T>,
715+
) -> Result {
716+
self.with_stream(stream, |ctx| unsafe {
717+
Ok(T::rotmg(
718+
ctx.raw,
719+
d1.as_device_ptr().as_mut_ptr(),
720+
d2.as_device_ptr().as_mut_ptr(),
721+
x1.as_device_ptr().as_mut_ptr(),
722+
y1.as_device_ptr().as_ptr(),
723+
param.as_device_ptr().as_mut_ptr(),
724+
)
725+
.to_result()?)
726+
})
727+
}
728+
729+
/// Constructs the modified givens transformation that zeros out the second entry of a 2x1 vector.
730+
pub fn rotmg<T: Level1 + Float>(
731+
&mut self,
732+
stream: &Stream,
733+
d1: &mut impl GpuBox<T>,
734+
d2: &mut impl GpuBox<T>,
735+
x1: &mut impl GpuBox<T>,
736+
y1: &mut impl GpuBox<T>,
737+
param: &mut impl GpuBox<T>,
738+
) -> Result {
739+
self.rotmg_strided(stream, d1, d2, x1, y1, param)
740+
}
741+
742+
/// Same as [`CublasContext::scal`] but with an explicit stride.
743+
pub fn scal_strided<T: Level1>(
744+
&mut self,
745+
stream: &Stream,
746+
n: usize,
747+
alpha: &impl GpuBox<T>,
748+
x: &mut impl GpuBuffer<T>,
749+
x_stride: Option<usize>,
750+
) -> Result {
751+
check_stride(x, n, x_stride);
752+
753+
self.with_stream(stream, |ctx| unsafe {
754+
Ok(T::scal(
755+
ctx.raw,
756+
n as i32,
757+
alpha.as_device_ptr().as_ptr(),
758+
x.as_device_ptr().as_mut_ptr(),
759+
x_stride.unwrap_or(1) as i32,
760+
)
761+
.to_result()?)
762+
})
763+
}
764+
765+
/// Scales vector `x` by `alpha` and overrides it with the result.
766+
pub fn scal<T: Level1>(
767+
&mut self,
768+
stream: &Stream,
769+
n: usize,
770+
alpha: &impl GpuBox<T>,
771+
x: &mut impl GpuBuffer<T>,
772+
) -> Result {
773+
self.scal_strided(stream, n, alpha, x, None)
774+
}
775+
776+
/// Same as [`CublasContext::swap`] but with an explicit stride.
777+
pub fn swap_strided<T: Level1>(
778+
&mut self,
779+
stream: &Stream,
780+
n: usize,
781+
x: &mut impl GpuBuffer<T>,
782+
x_stride: Option<usize>,
783+
y: &mut impl GpuBuffer<T>,
784+
y_stride: Option<usize>,
785+
) -> Result {
786+
check_stride(x, n, x_stride);
787+
check_stride(y, n, y_stride);
788+
789+
self.with_stream(stream, |ctx| unsafe {
790+
Ok(T::swap(
791+
ctx.raw,
792+
n as i32,
793+
x.as_device_ptr().as_mut_ptr(),
794+
x_stride.unwrap_or(1) as i32,
795+
y.as_device_ptr().as_mut_ptr(),
796+
y_stride.unwrap_or(1) as i32,
797+
)
798+
.to_result()?)
799+
})
800+
}
801+
802+
/// Swaps vectors `x` and `y`.
803+
pub fn swap<T: Level1>(
804+
&mut self,
805+
stream: &Stream,
806+
n: usize,
807+
x: &mut impl GpuBuffer<T>,
808+
y: &mut impl GpuBuffer<T>,
809+
) -> Result {
810+
self.swap_strided(stream, n, x, None, y, None)
811+
}
644812
}

crates/blastoff/src/raw/level1.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ pub trait Level1: BlasDatatype {
5959
c: *mut Self::FloatTy,
6060
s: *mut Self,
6161
) -> cublasStatus_t;
62+
unsafe fn rotm(
63+
handle: cublasHandle_t,
64+
n: c_int,
65+
x: *mut Self,
66+
incx: c_int,
67+
y: *mut Self,
68+
incy: c_int,
69+
param: *const Self::FloatTy,
70+
) -> cublasStatus_t;
71+
unsafe fn rotmg(
72+
handle: cublasHandle_t,
73+
d1: *mut Self,
74+
d2: *mut Self,
75+
x1: *mut Self,
76+
y1: *const Self,
77+
param: *mut Self,
78+
) -> cublasStatus_t;
6279
unsafe fn scal(
6380
handle: cublasHandle_t,
6481
n: c_int,
@@ -146,6 +163,27 @@ impl Level1 for f32 {
146163
) -> cublasStatus_t {
147164
cublasSrotg_v2(handle, a, b, c, s)
148165
}
166+
unsafe fn rotm(
167+
handle: cublasHandle_t,
168+
n: c_int,
169+
x: *mut Self,
170+
incx: c_int,
171+
y: *mut Self,
172+
incy: c_int,
173+
param: *const Self::FloatTy,
174+
) -> cublasStatus_t {
175+
cublasSrotm_v2(handle, n, x, incx, y, incy, param)
176+
}
177+
unsafe fn rotmg(
178+
handle: cublasHandle_t,
179+
d1: *mut Self,
180+
d2: *mut Self,
181+
x1: *mut Self,
182+
y1: *const Self,
183+
param: *mut Self,
184+
) -> cublasStatus_t {
185+
cublasSrotmg_v2(handle, d1, d2, x1, y1, param)
186+
}
149187
unsafe fn scal(
150188
handle: cublasHandle_t,
151189
n: c_int,
@@ -237,6 +275,27 @@ impl Level1 for f64 {
237275
) -> cublasStatus_t {
238276
cublasDrotg_v2(handle, a, b, c, s)
239277
}
278+
unsafe fn rotm(
279+
handle: cublasHandle_t,
280+
n: c_int,
281+
x: *mut Self,
282+
incx: c_int,
283+
y: *mut Self,
284+
incy: c_int,
285+
param: *const Self::FloatTy,
286+
) -> cublasStatus_t {
287+
cublasDrotm_v2(handle, n, x, incx, y, incy, param)
288+
}
289+
unsafe fn rotmg(
290+
handle: cublasHandle_t,
291+
d1: *mut Self,
292+
d2: *mut Self,
293+
x1: *mut Self,
294+
y1: *const Self,
295+
param: *mut Self,
296+
) -> cublasStatus_t {
297+
cublasDrotmg_v2(handle, d1, d2, x1, y1, param)
298+
}
240299
unsafe fn scal(
241300
handle: cublasHandle_t,
242301
n: c_int,
@@ -328,6 +387,27 @@ impl Level1 for Complex32 {
328387
) -> cublasStatus_t {
329388
cublasCrotg_v2(handle, a.cast(), b.cast(), c, s.cast())
330389
}
390+
unsafe fn rotm(
391+
_handle: cublasHandle_t,
392+
_n: c_int,
393+
_x: *mut Self,
394+
_incx: c_int,
395+
_y: *mut Self,
396+
_incy: c_int,
397+
_param: *const Self::FloatTy,
398+
) -> cublasStatus_t {
399+
unreachable!()
400+
}
401+
unsafe fn rotmg(
402+
_handle: cublasHandle_t,
403+
_d1: *mut Self,
404+
_d2: *mut Self,
405+
_x1: *mut Self,
406+
_y1: *const Self,
407+
_param: *mut Self,
408+
) -> cublasStatus_t {
409+
unreachable!()
410+
}
331411
unsafe fn scal(
332412
handle: cublasHandle_t,
333413
n: c_int,
@@ -419,6 +499,27 @@ impl Level1 for Complex64 {
419499
) -> cublasStatus_t {
420500
cublasZrotg_v2(handle, a.cast(), b.cast(), c, s.cast())
421501
}
502+
unsafe fn rotm(
503+
_handle: cublasHandle_t,
504+
_n: c_int,
505+
_x: *mut Self,
506+
_incx: c_int,
507+
_y: *mut Self,
508+
_incy: c_int,
509+
_param: *const Self::FloatTy,
510+
) -> cublasStatus_t {
511+
unreachable!()
512+
}
513+
unsafe fn rotmg(
514+
_handle: cublasHandle_t,
515+
_d1: *mut Self,
516+
_d2: *mut Self,
517+
_x1: *mut Self,
518+
_y1: *const Self,
519+
_param: *mut Self,
520+
) -> cublasStatus_t {
521+
unreachable!()
522+
}
422523
unsafe fn scal(
423524
handle: cublasHandle_t,
424525
n: c_int,

0 commit comments

Comments
 (0)