diff --git a/src/compiler/nvcc.rs b/src/compiler/nvcc.rs index a25d7d828..925d7029c 100644 --- a/src/compiler/nvcc.rs +++ b/src/compiler/nvcc.rs @@ -655,41 +655,66 @@ where // but can optionally be run in parallel to other groups if the user requested via // `nvcc --threads`. - let mut no_more_groups = false; - let mut command_groups: Vec> = vec![]; - let preprocessor_flag = match host_compiler { NvccHostCompiler::Msvc => "-P", _ => "-E", } .to_owned(); - for (_, dir, exe, args) in all_commands { - if log_enabled!(log::Level::Trace) { - trace!( - "transformed nvcc command: {:?}", - [ - &[format!("cd {} &&", dir.to_string_lossy()).to_string()], - &[exe.to_str().unwrap_or_default().to_string()][..], - &args[..] - ] - .concat() - .join(" ") - ); - } + let gen_module_id_file_flag = "--gen_module_id_file".to_owned(); + let mut cuda_front_end_group = Vec::::new(); + let mut final_assembly_group = Vec::::new(); + let mut device_compile_groups = HashMap::>::new(); - let (env_vars, cacheable) = match exe.file_stem().and_then(|s| s.to_str()) { + for (_, dir, exe, args) in all_commands { + let mut args = args.clone(); + + if let (env_vars, cacheable, Some(group)) = match exe.file_stem().and_then(|s| s.to_str()) { + // fatbinary and nvlink are not cacheable + Some("fatbinary") | Some("nvlink") => ( + env_vars.clone(), + Cacheable::No, + Some(&mut final_assembly_group), + ), // cicc and ptxas are cacheable - Some("cicc") | Some("ptxas") => (env_vars.clone(), Cacheable::Yes), - // cudafe++, nvlink, and fatbinary are not cacheable - Some("cudafe++") | Some("nvlink") => (env_vars.clone(), Cacheable::No), - Some("fatbinary") => { - // The fatbinary command represents the start of the last group - if !no_more_groups { - command_groups.push(vec![]); + Some("cicc") => { + // Remove the `--gen_module_id_file` flag + if let Some(idx) = args.iter().position(|x| x == &gen_module_id_file_flag) { + args.splice(idx..idx + 1, []); + } + let group = device_compile_groups.get_mut(&args[args.len() - 3]); + (env_vars.clone(), Cacheable::Yes, group) + } + Some("ptxas") => { + // Remove the `--gen_module_id_file` flag + if let Some(idx) = args.iter().position(|x| x == &gen_module_id_file_flag) { + args.splice(idx..idx + 1, []); + } + let group = device_compile_groups.values_mut().find(|cmds| { + if let Some(cicc) = cmds.last() { + if let Some(cicc_out) = cicc.args.last() { + return cicc_out == &args[args.len() - 3]; + } + } + false + }); + (env_vars.clone(), Cacheable::Yes, group) + } + // cudafe++ is not cacheable + Some("cudafe++") => { + // Fix for CTK < 12.0: + // Add `--gen_module_id_file` if the cudafe++ args include `--module_id_file_name` + if !args.contains(&gen_module_id_file_flag) { + if let Some(idx) = args.iter().position(|x| x == "--module_id_file_name") { + // Insert `--gen_module_id_file` just before `--module_id_file_name` to match nvcc behavior + args.splice(idx..idx, [gen_module_id_file_flag.clone()]); + } } - no_more_groups = true; - (env_vars.clone(), Cacheable::No) + ( + env_vars.clone(), + Cacheable::No, + Some(&mut cuda_front_end_group), + ) } _ => { // All generated host compiler commands include one of these defines. @@ -705,13 +730,35 @@ where continue; } if args.contains(&preprocessor_flag) { - // Each preprocessor step represents the start of a new command - // group, unless it comes after a call to fatbinary. - if !no_more_groups { - command_groups.push(vec![]); + // Each preprocessor step represents the start of a new command group + if let Some(out_file) = args.last().and_then(|o| { + PathBuf::from(o).file_name().and_then(|o| { + o.to_str().and_then(|o| { + // If the output file ends with... + // * .cpp1.ii - cicc/ptxas input + // * .cpp4.ii - cudafe++ input + if o.ends_with(".cpp1.ii") { + Some(o.to_owned()) + } else { + None + } + }) + }) + }) { + let new_device_compile_group = vec![]; + device_compile_groups.insert(out_file.clone(), new_device_compile_group); + ( + env_vars.clone(), + Cacheable::No, + device_compile_groups.get_mut(&out_file), + ) + } else { + ( + env_vars.clone(), + Cacheable::No, + Some(&mut cuda_front_end_group), + ) } - // Do not run preprocessor calls through sccache - (env_vars.clone(), Cacheable::No) } else { // Returns Cacheable::Yes to indicate we _do_ want to run this host // compiler call through sccache (because it may be distributed), @@ -732,31 +779,40 @@ where .cloned() .collect::>(), Cacheable::Yes, + Some(&mut final_assembly_group), ) } } - }; + } { + if log_enabled!(log::Level::Trace) { + trace!( + "transformed nvcc command: {:?}", + [ + &[format!("cd {} &&", dir.to_string_lossy()).to_string()], + &[exe.to_str().unwrap_or_default().to_string()][..], + &args[..] + ] + .concat() + .join(" ") + ); + } - // Initialize the first group in case the first command isn't a call to the host preprocessor, - // i.e. `nvcc -o test.o -c test.c` - if command_groups.is_empty() { - command_groups.push(vec![]); + group.push(NvccGeneratedSubcommand { + exe: exe.clone(), + args: args.clone(), + cwd: dir.into(), + env_vars, + cacheable, + }); } - - match command_groups.last_mut() { - None => {} - Some(group) => { - group.push(NvccGeneratedSubcommand { - exe: exe.clone(), - args: args.clone(), - cwd: dir.into(), - env_vars, - cacheable, - }); - } - }; } + let mut command_groups = vec![]; + + command_groups.push(cuda_front_end_group); + command_groups.extend(device_compile_groups.into_values()); + command_groups.push(final_assembly_group); + Ok(command_groups) }