@@ -4,7 +4,7 @@ use crate::{
4
4
context:: CublasContext ,
5
5
error:: { Error , ToResult } ,
6
6
raw:: { ComplexLevel1 , FloatLevel1 , Level1 } ,
7
- BlasDatatype ,
7
+ BlasDatatype , Float ,
8
8
} ;
9
9
use cust:: memory:: { GpuBox , GpuBuffer } ;
10
10
use cust:: stream:: Stream ;
@@ -641,4 +641,172 @@ impl CublasContext {
641
641
) -> Result {
642
642
self . rot_strided ( stream, n, x, None , y, None , c, s)
643
643
}
644
+
645
+ /// Constructs the givens rotation matrix that zeros out the second entry of a 2x1 vector.
646
+ pub fn rotg < T : Level1 > (
647
+ & mut self ,
648
+ stream : & Stream ,
649
+ a : & mut impl GpuBox < T > ,
650
+ b : & mut impl GpuBox < T > ,
651
+ c : & mut impl GpuBox < T :: FloatTy > ,
652
+ s : & mut impl GpuBox < T > ,
653
+ ) -> Result {
654
+ self . with_stream ( stream, |ctx| unsafe {
655
+ Ok ( T :: rotg (
656
+ ctx. raw ,
657
+ a. as_device_ptr ( ) . as_mut_ptr ( ) ,
658
+ b. as_device_ptr ( ) . as_mut_ptr ( ) ,
659
+ c. as_device_ptr ( ) . as_mut_ptr ( ) ,
660
+ s. as_device_ptr ( ) . as_mut_ptr ( ) ,
661
+ )
662
+ . to_result ( ) ?)
663
+ } )
664
+ }
665
+
666
+ /// Same as [`CublasContext::rotm`] but with an explicit stride.
667
+ pub fn rotm_strided < T : Level1 + Float > (
668
+ & mut self ,
669
+ stream : & Stream ,
670
+ n : usize ,
671
+ x : & mut impl GpuBuffer < T > ,
672
+ x_stride : Option < usize > ,
673
+ y : & mut impl GpuBuffer < T > ,
674
+ y_stride : Option < usize > ,
675
+ param : & impl GpuBox < T :: FloatTy > ,
676
+ ) -> Result {
677
+ check_stride ( x, n, x_stride) ;
678
+ check_stride ( y, n, y_stride) ;
679
+
680
+ self . with_stream ( stream, |ctx| unsafe {
681
+ Ok ( T :: rotm (
682
+ ctx. raw ,
683
+ n as i32 ,
684
+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
685
+ x_stride. unwrap_or ( 1 ) as i32 ,
686
+ y. as_device_ptr ( ) . as_mut_ptr ( ) ,
687
+ y_stride. unwrap_or ( 1 ) as i32 ,
688
+ param. as_device_ptr ( ) . as_ptr ( ) ,
689
+ )
690
+ . to_result ( ) ?)
691
+ } )
692
+ }
693
+
694
+ /// Applies the modified givens transformation to vectors `x` and `y`.
695
+ pub fn rotm < T : Level1 + Float > (
696
+ & mut self ,
697
+ stream : & Stream ,
698
+ n : usize ,
699
+ x : & mut impl GpuBuffer < T > ,
700
+ y : & mut impl GpuBuffer < T > ,
701
+ param : & impl GpuBox < T :: FloatTy > ,
702
+ ) -> Result {
703
+ self . rotm_strided ( stream, n, x, None , y, None , param)
704
+ }
705
+
706
+ /// Same as [`CublasContext::rotmg`] but with an explicit stride.
707
+ pub fn rotmg_strided < T : Level1 + Float > (
708
+ & mut self ,
709
+ stream : & Stream ,
710
+ d1 : & mut impl GpuBox < T > ,
711
+ d2 : & mut impl GpuBox < T > ,
712
+ x1 : & mut impl GpuBox < T > ,
713
+ y1 : & mut impl GpuBox < T > ,
714
+ param : & mut impl GpuBox < T > ,
715
+ ) -> Result {
716
+ self . with_stream ( stream, |ctx| unsafe {
717
+ Ok ( T :: rotmg (
718
+ ctx. raw ,
719
+ d1. as_device_ptr ( ) . as_mut_ptr ( ) ,
720
+ d2. as_device_ptr ( ) . as_mut_ptr ( ) ,
721
+ x1. as_device_ptr ( ) . as_mut_ptr ( ) ,
722
+ y1. as_device_ptr ( ) . as_ptr ( ) ,
723
+ param. as_device_ptr ( ) . as_mut_ptr ( ) ,
724
+ )
725
+ . to_result ( ) ?)
726
+ } )
727
+ }
728
+
729
+ /// Constructs the modified givens transformation that zeros out the second entry of a 2x1 vector.
730
+ pub fn rotmg < T : Level1 + Float > (
731
+ & mut self ,
732
+ stream : & Stream ,
733
+ d1 : & mut impl GpuBox < T > ,
734
+ d2 : & mut impl GpuBox < T > ,
735
+ x1 : & mut impl GpuBox < T > ,
736
+ y1 : & mut impl GpuBox < T > ,
737
+ param : & mut impl GpuBox < T > ,
738
+ ) -> Result {
739
+ self . rotmg_strided ( stream, d1, d2, x1, y1, param)
740
+ }
741
+
742
+ /// Same as [`CublasContext::scal`] but with an explicit stride.
743
+ pub fn scal_strided < T : Level1 > (
744
+ & mut self ,
745
+ stream : & Stream ,
746
+ n : usize ,
747
+ alpha : & impl GpuBox < T > ,
748
+ x : & mut impl GpuBuffer < T > ,
749
+ x_stride : Option < usize > ,
750
+ ) -> Result {
751
+ check_stride ( x, n, x_stride) ;
752
+
753
+ self . with_stream ( stream, |ctx| unsafe {
754
+ Ok ( T :: scal (
755
+ ctx. raw ,
756
+ n as i32 ,
757
+ alpha. as_device_ptr ( ) . as_ptr ( ) ,
758
+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
759
+ x_stride. unwrap_or ( 1 ) as i32 ,
760
+ )
761
+ . to_result ( ) ?)
762
+ } )
763
+ }
764
+
765
+ /// Scales vector `x` by `alpha` and overrides it with the result.
766
+ pub fn scal < T : Level1 > (
767
+ & mut self ,
768
+ stream : & Stream ,
769
+ n : usize ,
770
+ alpha : & impl GpuBox < T > ,
771
+ x : & mut impl GpuBuffer < T > ,
772
+ ) -> Result {
773
+ self . scal_strided ( stream, n, alpha, x, None )
774
+ }
775
+
776
+ /// Same as [`CublasContext::swap`] but with an explicit stride.
777
+ pub fn swap_strided < T : Level1 > (
778
+ & mut self ,
779
+ stream : & Stream ,
780
+ n : usize ,
781
+ x : & mut impl GpuBuffer < T > ,
782
+ x_stride : Option < usize > ,
783
+ y : & mut impl GpuBuffer < T > ,
784
+ y_stride : Option < usize > ,
785
+ ) -> Result {
786
+ check_stride ( x, n, x_stride) ;
787
+ check_stride ( y, n, y_stride) ;
788
+
789
+ self . with_stream ( stream, |ctx| unsafe {
790
+ Ok ( T :: swap (
791
+ ctx. raw ,
792
+ n as i32 ,
793
+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
794
+ x_stride. unwrap_or ( 1 ) as i32 ,
795
+ y. as_device_ptr ( ) . as_mut_ptr ( ) ,
796
+ y_stride. unwrap_or ( 1 ) as i32 ,
797
+ )
798
+ . to_result ( ) ?)
799
+ } )
800
+ }
801
+
802
+ /// Swaps vectors `x` and `y`.
803
+ pub fn swap < T : Level1 > (
804
+ & mut self ,
805
+ stream : & Stream ,
806
+ n : usize ,
807
+ x : & mut impl GpuBuffer < T > ,
808
+ y : & mut impl GpuBuffer < T > ,
809
+ ) -> Result {
810
+ self . swap_strided ( stream, n, x, None , y, None )
811
+ }
644
812
}
0 commit comments