Skip to content

Commit 4a6e8b7

Browse files
authored
Improve CUDA support (#612)
* Perform CUDA --device-link. This allows to perform the final link with system linker. * Add 'cudart' method mimicking the '--cudart' nvcc command-line option. Try to locate the library in standard location relative to nvcc command. If it fails, user is held responsible for specifying one in RUSTFLAGS. * Add dummy CUDA test to cc-test. Execution is bound to fail without card, but the failure is ignored. It's rather a compile-n-link test. The test is suppressed if 'nvcc' is not found on the $PATH. * Add dummy CUDA CI test. * Harmonize CUDA support with NVCC default --cudart static. This can interfere with current deployments in the wild, in which case some adjustments might be required. Most notably one might have to add .cuda("none") to the corresponding Builder instantiation to restore the original behaviour.
1 parent a11e066 commit 4a6e8b7

File tree

7 files changed

+160
-0
lines changed

7 files changed

+160
-0
lines changed

.github/workflows/main.yml

+19
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,25 @@ jobs:
8787
- run: cargo test ${{ matrix.no_run }} --manifest-path cc-test/Cargo.toml --target ${{ matrix.target }} --features parallel
8888
- run: cargo test ${{ matrix.no_run }} --manifest-path cc-test/Cargo.toml --target ${{ matrix.target }} --release
8989

90+
cuda:
91+
name: Test CUDA support
92+
runs-on: ubuntu-20.04
93+
steps:
94+
- uses: actions/checkout@master
95+
- name: Install cuda-minimal-build-11-4
96+
shell: bash
97+
run: |
98+
# https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network
99+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
100+
sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
101+
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
102+
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
103+
sudo apt-get update
104+
sudo apt-get -y install cuda-minimal-build-11-4
105+
- name: Test 'cudart' feature
106+
shell: bash
107+
run: env PATH=/usr/local/cuda/bin:$PATH cargo test --manifest-path cc-test/Cargo.toml --features test_cuda
108+
90109
msrv:
91110
name: MSRV
92111
runs-on: ${{ matrix.os }}

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ fn main() {
189189
cc::Build::new()
190190
// Switch to CUDA C++ library compilation using NVCC.
191191
.cuda(true)
192+
.cudart("static")
192193
// Generate code for Maxwell (GTX 970, 980, 980 Ti, Titan X).
193194
.flag("-gencode").flag("arch=compute_52,code=sm_52")
194195
// Generate code for Maxwell (Jetson TX1).

cc-test/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ test = false
1111

1212
[build-dependencies]
1313
cc = { path = ".." }
14+
which = "^4.0"
1415

1516
[features]
1617
parallel = ["cc/parallel"]
18+
test_cuda = []

cc-test/build.rs

+18
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,24 @@ fn main() {
3535
.cpp(true)
3636
.compile("baz");
3737

38+
if env::var("CARGO_FEATURE_TEST_CUDA").is_ok() {
39+
// Detect if there is CUDA compiler and engage "cuda" feature.
40+
let nvcc = match env::var("NVCC") {
41+
Ok(var) => which::which(var),
42+
Err(_) => which::which("nvcc"),
43+
};
44+
if nvcc.is_ok() {
45+
cc::Build::new()
46+
.cuda(true)
47+
.cudart("static")
48+
.file("src/cuda.cu")
49+
.compile("libcuda.a");
50+
51+
// Communicate [cfg(feature = "cuda")] to test/all.rs.
52+
println!("cargo:rustc-cfg=feature=\"cuda\"");
53+
}
54+
}
55+
3856
if target.contains("windows") {
3957
cc::Build::new().file("src/windows.c").compile("windows");
4058
}

cc-test/src/cuda.cu

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include <cuda.h>
2+
3+
__global__ void kernel() {}
4+
5+
extern "C" void cuda_kernel() { kernel<<<1, 1>>>(); }

cc-test/tests/all.rs

+11
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,14 @@ fn opt_linkage() {
5656
assert_eq!(answer(), 42);
5757
}
5858
}
59+
60+
#[cfg(feature = "cuda")]
61+
#[test]
62+
fn cuda_here() {
63+
extern "C" {
64+
fn cuda_kernel();
65+
}
66+
unsafe {
67+
cuda_kernel();
68+
}
69+
}

src/lib.rs

+104
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ pub struct Build {
103103
cpp_link_stdlib: Option<Option<String>>,
104104
cpp_set_stdlib: Option<String>,
105105
cuda: bool,
106+
cudart: Option<String>,
106107
target: Option<String>,
107108
host: Option<String>,
108109
out_dir: Option<PathBuf>,
@@ -298,6 +299,7 @@ impl Build {
298299
cpp_link_stdlib: None,
299300
cpp_set_stdlib: None,
300301
cuda: false,
302+
cudart: None,
301303
target: None,
302304
host: None,
303305
out_dir: None,
@@ -611,6 +613,20 @@ impl Build {
611613
self.cuda = cuda;
612614
if cuda {
613615
self.cpp = true;
616+
self.cudart = Some("static".to_string());
617+
}
618+
self
619+
}
620+
621+
/// Link CUDA run-time.
622+
///
623+
/// This option mimics the `--cudart` NVCC command-line option. Just like
624+
/// the original it accepts `{none|shared|static}`, with default being
625+
/// `static`. The method has to be invoked after `.cuda(true)`, or not
626+
/// at all, if the default is right for the project.
627+
pub fn cudart(&mut self, cudart: &str) -> &mut Build {
628+
if self.cuda {
629+
self.cudart = Some(cudart.to_string());
614630
}
615631
self
616632
}
@@ -996,6 +1012,56 @@ impl Build {
9961012
}
9971013
}
9981014

1015+
let cudart = match &self.cudart {
1016+
Some(opt) => opt.as_str(), // {none|shared|static}
1017+
None => "none",
1018+
};
1019+
if cudart != "none" {
1020+
if let Some(nvcc) = which(&self.get_compiler().path) {
1021+
// Try to figure out the -L search path. If it fails,
1022+
// it's on user to specify one by passing it through
1023+
// RUSTFLAGS environment variable.
1024+
let mut libtst = false;
1025+
let mut libdir = nvcc;
1026+
libdir.pop(); // remove 'nvcc'
1027+
libdir.push("..");
1028+
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
1029+
if cfg!(target_os = "linux") {
1030+
libdir.push("targets");
1031+
libdir.push(target_arch.to_owned() + "-linux");
1032+
libdir.push("lib");
1033+
libtst = true;
1034+
} else if cfg!(target_env = "msvc") {
1035+
libdir.push("lib");
1036+
match target_arch.as_str() {
1037+
"x86_64" => {
1038+
libdir.push("x64");
1039+
libtst = true;
1040+
}
1041+
"x86" => {
1042+
libdir.push("Win32");
1043+
libtst = true;
1044+
}
1045+
_ => libtst = false,
1046+
}
1047+
}
1048+
if libtst && libdir.is_dir() {
1049+
println!(
1050+
"cargo:rustc-link-search=native={}",
1051+
libdir.to_str().unwrap()
1052+
);
1053+
}
1054+
1055+
// And now the -l flag.
1056+
let lib = match cudart {
1057+
"shared" => "cudart",
1058+
"static" => "cudart_static",
1059+
bad => panic!("unsupported cudart option: {}", bad),
1060+
};
1061+
println!("cargo:rustc-link-lib={}", lib);
1062+
}
1063+
}
1064+
9991065
Ok(())
10001066
}
10011067

@@ -1205,6 +1271,9 @@ impl Build {
12051271
if !msvc || !is_asm || !is_arm {
12061272
cmd.arg("-c");
12071273
}
1274+
if self.cuda && self.files.len() > 1 {
1275+
cmd.arg("--device-c");
1276+
}
12081277
cmd.arg(&obj.src);
12091278
if cfg!(target_os = "macos") {
12101279
self.fix_env_for_apple_os(&mut cmd)?;
@@ -1811,6 +1880,21 @@ impl Build {
18111880
self.assemble_progressive(dst, chunk)?;
18121881
}
18131882

1883+
if self.cuda {
1884+
// Link the device-side code and add it to the target library,
1885+
// so that non-CUDA linker can link the final binary.
1886+
1887+
let out_dir = self.get_out_dir()?;
1888+
let dlink = out_dir.join(lib_name.to_owned() + "_dlink.o");
1889+
let mut nvcc = self.get_compiler().to_command();
1890+
nvcc.arg("--device-link")
1891+
.arg("-o")
1892+
.arg(dlink.clone())
1893+
.arg(dst);
1894+
run(&mut nvcc, "nvcc")?;
1895+
self.assemble_progressive(dst, &[dlink])?;
1896+
}
1897+
18141898
let target = self.get_target()?;
18151899
if target.contains("msvc") {
18161900
// The Rust compiler will look for libfoo.a and foo.lib, but the
@@ -3100,3 +3184,23 @@ fn map_darwin_target_from_rust_to_compiler_architecture(target: &str) -> Option<
31003184
None
31013185
}
31023186
}
3187+
3188+
fn which(tool: &Path) -> Option<PathBuf> {
3189+
fn check_exe(exe: &mut PathBuf) -> bool {
3190+
let exe_ext = std::env::consts::EXE_EXTENSION;
3191+
exe.exists() || (!exe_ext.is_empty() && exe.set_extension(exe_ext) && exe.exists())
3192+
}
3193+
3194+
// If |tool| is not just one "word," assume it's an actual path...
3195+
if tool.components().count() > 1 {
3196+
let mut exe = PathBuf::from(tool);
3197+
return if check_exe(&mut exe) { Some(exe) } else { None };
3198+
}
3199+
3200+
// Loop through PATH entries searching for the |tool|.
3201+
let path_entries = env::var_os("PATH")?;
3202+
env::split_paths(&path_entries).find_map(|path_entry| {
3203+
let mut exe = path_entry.join(tool);
3204+
return if check_exe(&mut exe) { Some(exe) } else { None };
3205+
})
3206+
}

0 commit comments

Comments
 (0)