Skip to content

Commit

Permalink
Support nvc as the host compiler for nvcc
Browse files Browse the repository at this point in the history
  • Loading branch information
robertmaynard authored and sylvestre committed Sep 28, 2023
1 parent a39a121 commit 896cef1
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 30 deletions.
79 changes: 61 additions & 18 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::compiler::gcc::Gcc;
use crate::compiler::msvc;
use crate::compiler::msvc::Msvc;
use crate::compiler::nvcc::Nvcc;
use crate::compiler::nvcc::NvccHostCompiler;
use crate::compiler::rust::{Rust, RustupProxy};
use crate::compiler::tasking_vx::TaskingVX;
#[cfg(feature = "dist-client")]
Expand Down Expand Up @@ -1085,10 +1086,12 @@ where
//
// We prefix the information we need with `compiler_id` and `compiler_version`
// so that we can support compilers that insert pre-amble code even in `-E` mode
let test = b"#if defined(__NVCC__) && !defined(_MSC_VER)
compiler_id=nvcc
let test = b"#if defined(__NVCC__) && defined(__NVCOMPILER)
compiler_id=nvcc-nvhpc
#elif defined(__NVCC__) && defined(_MSC_VER)
compiler_id=nvcc-msvc
#elif defined(__NVCC__)
compiler_id=nvcc
#elif defined(_MSC_VER) && !defined(__clang__)
compiler_id=msvc
#elif defined(_MSC_VER) && defined(_MT)
Expand Down Expand Up @@ -1228,12 +1231,16 @@ compiler_version=__VERSION__
.await
.map(|c| Box::new(c) as Box<dyn Compiler<T>>);
}
"nvcc" | "nvcc-msvc" => {
let is_msvc = kind == "nvcc-msvc";
debug!("Found NVCC");
"nvcc" | "nvcc-msvc" | "nvcc-nvhpc" => {
let host_compiler = match kind {
"nvcc-nvhpc" => NvccHostCompiler::NVHPC,
"nvcc-msvc" => NvccHostCompiler::MSVC,
"nvcc" => NvccHostCompiler::GCC,
&_ => NvccHostCompiler::GCC,
};
return CCompiler::new(
Nvcc {
is_msvc,
host_compiler: host_compiler,
version: version.clone(),
},
executable,
Expand Down Expand Up @@ -1298,7 +1305,10 @@ mod test {
let creator = new_creator();
let runtime = single_threaded_runtime();
let pool = runtime.handle();
next_command(&creator, Ok(MockChild::new(exit_status(0), "\n\ncompiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "\n\ncompiler_id=gcc", "")),
);
let c = detect_compiler(creator, &f.bins[0], f.tempdir.path(), &[], &[], pool, None)
.wait()
.unwrap()
Expand All @@ -1312,7 +1322,10 @@ mod test {
let creator = new_creator();
let runtime = single_threaded_runtime();
let pool = runtime.handle();
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=clang\n", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=clang\n", "")),
);
let c = detect_compiler(creator, &f.bins[0], f.tempdir.path(), &[], &[], pool, None)
.wait()
.unwrap()
Expand All @@ -1335,7 +1348,10 @@ mod test {
let prefix = String::from("blah: ");
let stdout = format!("{}{}\r\n", prefix, s);
// Compiler detection output
next_command(&creator, Ok(MockChild::new(exit_status(0), "\ncompiler_id=msvc\n", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "\ncompiler_id=msvc\n", "")),
);
// showincludes prefix detection output
next_command(
&creator,
Expand All @@ -1354,7 +1370,10 @@ mod test {
let creator = new_creator();
let runtime = single_threaded_runtime();
let pool = runtime.handle();
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=nvcc\n", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=nvcc\n", "")),
);
let c = detect_compiler(creator, &f.bins[0], f.tempdir.path(), &[], &[], pool, None)
.wait()
.unwrap()
Expand Down Expand Up @@ -1486,7 +1505,10 @@ LLVM version: 6.0",
let creator = new_creator();
let runtime = single_threaded_runtime();
let pool = runtime.handle();
next_command(&creator, Ok(MockChild::new(exit_status(0), "\ncompiler_id=diab\n", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "\ncompiler_id=diab\n", "")),
);
let c = detect_compiler(creator, &f.bins[0], f.tempdir.path(), &[], &[], pool, None)
.wait()
.unwrap()
Expand Down Expand Up @@ -1588,7 +1610,10 @@ LLVM version: 6.0",
let pool = runtime.handle();
let f = TestFixture::new();
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(creator, &f.bins[0], f.tempdir.path(), &[], &[], pool, None)
.wait()
.unwrap()
Expand All @@ -1607,7 +1632,10 @@ LLVM version: 6.0",
let storage = DiskCache::new(f.tempdir.path().join("cache"), u64::MAX, &pool);
let storage = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down Expand Up @@ -1718,7 +1746,10 @@ LLVM version: 6.0",
let storage = DiskCache::new(f.tempdir.path().join("cache"), u64::MAX, &pool);
let storage = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down Expand Up @@ -1825,7 +1856,10 @@ LLVM version: 6.0",
let storage = MockStorage::new(None);
let storage: Arc<MockStorage> = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down Expand Up @@ -1906,7 +1940,10 @@ LLVM version: 6.0",
let storage = MockStorage::new(Some(storage_delay));
let storage: Arc<MockStorage> = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down Expand Up @@ -1979,7 +2016,10 @@ LLVM version: 6.0",
let storage = DiskCache::new(f.tempdir.path().join("cache"), u64::MAX, &pool);
let storage = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down Expand Up @@ -2170,7 +2210,10 @@ LLVM version: 6.0",
let storage = DiskCache::new(f.tempdir.path().join("cache"), u64::MAX, &pool);
let storage = Arc::new(storage);
// Pretend to be GCC.
next_command(&creator, Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")));
next_command(
&creator,
Ok(MockChild::new(exit_status(0), "compiler_id=gcc", "")),
);
let c = get_compiler_info(
creator.clone(),
&f.bins[0],
Expand Down
90 changes: 78 additions & 12 deletions src/compiler/nvcc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ use std::process;
use crate::errors::*;

/// A unit struct on which to implement `CCompilerImpl`.
#[derive(Clone, Debug)]
pub enum NvccHostCompiler {
GCC,
MSVC,
NVHPC,
}

#[derive(Clone, Debug)]
pub struct Nvcc {
pub is_msvc: bool,
pub host_compiler: NvccHostCompiler,
pub version: Option<String>,
}

Expand Down Expand Up @@ -149,12 +156,17 @@ impl CCompilerImpl for Nvcc {

//NVCC only supports `-E` when it comes after preprocessor
//and common flags.
let no_line_nums = match self.is_msvc {
true => "-Xcompiler=-EP",
false => "-Xcompiler=-P",
//
// nvc/nvc++ don't support no line numbers to console
// msvc requires the `-EP` flag to output no line numbers to console
// other host compilers are presumed to match `gcc` behavior
let no_line_num_flag = match self.host_compiler {
NvccHostCompiler::NVHPC => "",
NvccHostCompiler::MSVC => "-Xcompiler=-EP",
NvccHostCompiler::GCC => "-Xcompiler=-P",
};
cmd.arg("-E")
.arg(no_line_nums)
.arg(no_line_num_flag)
.env_clear()
.envs(env_vars.iter().map(|&(ref k, ref v)| (k, v)))
.current_dir(cwd);
Expand Down Expand Up @@ -262,18 +274,34 @@ mod test {
use std::collections::HashMap;
use std::path::PathBuf;

fn parse_arguments_(arguments: Vec<String>) -> CompilerArguments<ParsedArguments> {
fn parse_arguments_gcc(arguments: Vec<String>) -> CompilerArguments<ParsedArguments> {
let arguments = arguments.iter().map(OsString::from).collect::<Vec<_>>();
Nvcc {
host_compiler: NvccHostCompiler::GCC,
version: None,
}
.parse_arguments(&arguments, ".".as_ref())
}
fn parse_arguments_nvc(arguments: Vec<String>) -> CompilerArguments<ParsedArguments> {
let arguments = arguments.iter().map(OsString::from).collect::<Vec<_>>();
Nvcc {
is_msvc: false,
host_compiler: NvccHostCompiler::NVHPC,
version: None,
}
.parse_arguments(&arguments, ".".as_ref())
}

macro_rules! parses {
( $( $s:expr ),* ) => {
match parse_arguments_(vec![ $( $s.to_string(), )* ]) {
match parse_arguments_gcc(vec![ $( $s.to_string(), )* ]) {
CompilerArguments::Ok(a) => a,
o => panic!("Got unexpected parse result: {:?}", o),
}
}
}
macro_rules! parses_nvc {
( $( $s:expr ),* ) => {
match parse_arguments_nvc(vec![ $( $s.to_string(), )* ]) {
CompilerArguments::Ok(a) => a,
o => panic!("Got unexpected parse result: {:?}", o),
}
Expand All @@ -300,7 +328,7 @@ mod test {
}

#[test]
fn test_parse_arguments_simple_cu() {
fn test_parse_arguments_simple_cu_gcc() {
let a = parses!("-c", "foo.cu", "-o", "foo.o");
assert_eq!(Some("foo.cu"), a.input.to_str());
assert_eq!(Language::Cuda, a.language);
Expand All @@ -318,6 +346,25 @@ mod test {
assert!(a.common_args.is_empty());
}

#[test]
fn test_parse_arguments_simple_cu_nvc() {
let a = parses_nvc!("-c", "foo.cu", "-o", "foo.o");
assert_eq!(Some("foo.cu"), a.input.to_str());
assert_eq!(Language::Cuda, a.language);
assert_map_contains!(
a.outputs,
(
"obj",
ArtifactDescriptor {
path: "foo.o".into(),
optional: false
}
)
);
assert!(a.preprocessor_args.is_empty());
assert!(a.common_args.is_empty());
}

#[test]
fn test_parse_arguments_ccbin_no_path() {
let a = parses!("-ccbin=gcc", "-c", "foo.cu", "-o", "foo.o");
Expand Down Expand Up @@ -663,7 +710,18 @@ mod test {
fn test_parse_dlink_is_not_compilation() {
assert_eq!(
CompilerArguments::NotCompilation,
parse_arguments_(stringvec![
parse_arguments_gcc(stringvec![
"-forward-unknown-to-host-compiler",
"--generate-code=arch=compute_50,code=[compute_50,sm_50,sm_52]",
"-dlink",
"main.cu.o",
"-o",
"device_link.o"
])
);
assert_eq!(
CompilerArguments::NotCompilation,
parse_arguments_nvc(stringvec![
"-forward-unknown-to-host-compiler",
"--generate-code=arch=compute_50,code=[compute_50,sm_50,sm_52]",
"-dlink",
Expand All @@ -677,12 +735,20 @@ mod test {
fn test_parse_cant_cache_flags() {
assert_eq!(
CompilerArguments::CannotCache("-E", None),
parse_arguments_(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-E"])
parse_arguments_gcc(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-E"])
);
assert_eq!(
CompilerArguments::CannotCache("-E", None),
parse_arguments_nvc(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-E"])
);

assert_eq!(
CompilerArguments::CannotCache("-M", None),
parse_arguments_(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-M"])
parse_arguments_gcc(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-M"])
);
assert_eq!(
CompilerArguments::CannotCache("-M", None),
parse_arguments_nvc(stringvec!["-x", "cu", "-c", "foo.c", "-o", "foo.o", "-M"])
);
}
}

0 comments on commit 896cef1

Please sign in to comment.