Skip to content

Commit 786d15a

Browse files
ngiloq6 authored and cpsauer committed
generate nvcc flags with a script
1 parent 8866a19 commit 786d15a

File tree

2 files changed

+227
-10
lines changed

2 files changed

+227
-10
lines changed

nvcc_clang_diff.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python3
2+
3+
"""Generates a set of flags that are accepted by nvcc but not clang.
4+
5+
These should be stripped or rewritten before being stored in compile_commands.json.
6+
"""
7+
8+
import dataclasses
9+
import functools
10+
import shutil
11+
import subprocess
12+
13+
@functools.total_ordering
@dataclasses.dataclass
class Flag:
    """One nvcc command-line option, as parsed from a line of ``nvcc --help``.

    Ordering (``<``, ``<=``, ``>``, ``>=``) considers only ``(long, short)``.
    Equality is the dataclass-generated one and compares all three fields.
    """

    # Long spelling of the option, e.g. "--compiler-bindir".
    long: str
    # Short spelling of the option, e.g. "-ccbin".
    short: str
    # True when the help line carried more than two whitespace-separated
    # tokens, i.e. something appeared between the long and short names.
    has_args: bool

    def __lt__(self, other):
        """Order flags by long name, then by short name."""
        # Let Python fall back to its default handling (TypeError) for foreign
        # types rather than failing with an AttributeError on ``other.long``.
        if not isinstance(other, Flag):
            return NotImplemented
        return (self.long, self.short) < (other.long, other.short)
22+
23+
def flag_key(flag: str) -> str:
    """Return the part of *flag* that precedes its first ``=``.

    A token without ``=`` is returned unchanged, so ``-std=<value>`` and
    ``-std`` both yield ``-std``.
    """
    # partition() returns the whole string as the head when "=" is absent,
    # so no separate membership test is needed.
    return flag.partition("=")[0]
27+
28+
def _parse_nvcc_help(help_output: str) -> list[Flag]:
    """Build one ``Flag`` per option line found in ``nvcc --help`` text."""
    flags = []
    for line in help_output.splitlines():
        # Only lines that themselves start with "--" introduce an option;
        # everything else (descriptions, headings) is ignored.
        if not line.startswith("--"):
            continue
        # An option line looks like: --long args (-short)
        line_parts = line.split()
        short = line_parts[-1]
        # Strip the surrounding parentheses from the short name when present.
        if short.startswith("(") and short.endswith(")"):
            short = short[1:-1]
        # More than two tokens means something sits between long and short.
        flags.append(Flag(line_parts[0], short, has_args=len(line_parts) > 2))
    return flags


def get_nvcc_flags() -> list[Flag]:
    """Run ``nvcc --help`` and return the options it lists.

    nvcc is looked up on PATH, falling back to ``/usr/local/cuda/bin/nvcc``.

    Returns:
        The parsed flags, in the order they appear in the help output.

    Raises:
        FileNotFoundError: if the nvcc executable cannot be started.
        subprocess.CalledProcessError: if nvcc exits with a non-zero status.
    """
    nvcc = shutil.which("nvcc") or "/usr/local/cuda/bin/nvcc"
    help_output = subprocess.check_output([nvcc, "--help"], text=True, stderr=subprocess.STDOUT)
    return _parse_nvcc_help(help_output)
42+
43+
def get_clang_flags() -> set[str]:
    """Collect the option names that appear in ``clang --help``.

    clang is looked up on PATH, falling back to ``/usr/bin/clang``. Every
    whitespace-separated token of the help text that starts with ``-`` is
    reduced with ``flag_key`` and collected; a hand-maintained list of
    additional flags is then merged in.
    """
    clang_binary = shutil.which("clang")
    if not clang_binary:
        clang_binary = "/usr/bin/clang"
    help_text = subprocess.check_output(
        [clang_binary, "--help"],
        text=True,
        stderr=subprocess.STDOUT,
    )
    known = {flag_key(word) for word in help_text.split() if word.startswith("-")}
    # Fix this up manually based on https://clang.llvm.org/docs/ClangCommandLineReference.html
    known.update({
        "-Wreorder",
        "-Wno-deprecated-declarations",
        "-Werror",
        "-O",
        "--help",
        "-l",
        "-m64",
        "--shared",
        "-shared",
    })
    return known
50+
51+
def _classify_nvcc_flags(nvcc_flags, clang_flags):
    """Sort nvcc flags into (no-arg list, with-arg list, rewrite mapping).

    A flag whose long and short spellings are both known to clang is dropped.
    A flag with exactly one spelling known to clang goes into the rewrite
    mapping, keyed by the unknown spelling and pointing at the known one.
    All remaining flags are split by ``has_args``.
    """
    no_arg = []
    with_arg = []
    rewrite = {}
    for flag in nvcc_flags:
        long_known = flag.long in clang_flags
        short_known = flag.short in clang_flags
        if long_known and short_known:
            continue
        if short_known:
            rewrite[flag.long] = flag.short
        elif long_known:
            rewrite[flag.short] = flag.long
        elif flag.has_args:
            with_arg.append(flag)
        else:
            no_arg.append(flag)
    return no_arg, with_arg, rewrite


def _print_flag_set(name, flags):
    """Print *flags* as a Python set literal assigned to *name*."""
    print(f"{name} = {{")
    print(" # long name, short name")
    for flag in sorted(flags):
        print(f" '{flag.long}', '{flag.short}',")
    print("}")


def main():
    """Print Python literals describing nvcc flags that clang does not accept.

    Three definitions are written to stdout: ``_nvcc_flags_no_arg``,
    ``_nvcc_flags_with_arg`` and ``_nvcc_rewrite_flags``.
    """
    # nvcc is queried before clang, matching left-to-right argument evaluation.
    no_arg, with_arg, rewrite = _classify_nvcc_flags(get_nvcc_flags(), get_clang_flags())

    _print_flag_set("_nvcc_flags_no_arg", no_arg)
    _print_flag_set("_nvcc_flags_with_arg", with_arg)

    print("_nvcc_rewrite_flags = {")
    print(" # NVCC flag: clang flag")
    for nvcc_flag in sorted(rewrite):
        print(f" '{nvcc_flag}': '{rewrite[nvcc_flag]}',")
    print("}")


if __name__ == "__main__":
    main()

refresh.template.py

Lines changed: 135 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -771,21 +771,144 @@ def _all_platform_patch(compile_args: typing.List[str]):
771771

772772
return compile_args
773773

774+
# Generated by the script nvcc_clang_diff.py
774775
_nvcc_flags_no_arg = {
775776
# long name, short name
776-
'--expt-relaxed-constexpr', '-expt-relaxed-constexpr',
777+
'--Wdefault-stream-launch', '-Wdefault-stream-launch',
778+
'--Wext-lambda-captures-this', '-Wext-lambda-captures-this',
779+
'--Wmissing-launch-bounds', '-Wmissing-launch-bounds',
780+
'--Wno-deprecated-gpu-targets', '-Wno-deprecated-gpu-targets',
781+
'--allow-unsupported-compiler', '-allow-unsupported-compiler',
782+
'--augment-host-linker-script', '-aug-hls',
783+
'--clean-targets', '-clean',
784+
'--compile-as-tools-patch', '-astoolspatch',
785+
'--cubin', '-cubin',
786+
'--cuda', '-cuda',
787+
'--device-c', '-dc',
788+
'--device-link', '-dlink',
789+
'--device-w', '-dw',
790+
'--display-error-number', '-err-no',
791+
'--dlink-time-opt', '-dlto',
792+
'--dont-use-profile', '-noprof',
793+
'--dryrun', '-dryrun',
777794
'--expt-extended-lambda', '-expt-extended-lambda',
778-
'--extended-lambda', '-extended-lambda'}
779-
_nvcc_flags_with_arg = (
795+
'--expt-relaxed-constexpr', '-expt-relaxed-constexpr',
796+
'--extended-lambda', '-extended-lambda',
797+
'--extensible-whole-program', '-ewp',
798+
'--extra-device-vectorization', '-extra-device-vectorization',
799+
'--fatbin', '-fatbin',
800+
'--forward-unknown-opts', '-forward-unknown-opts',
801+
'--forward-unknown-to-host-compiler', '-forward-unknown-to-host-compiler',
802+
'--forward-unknown-to-host-linker', '-forward-unknown-to-host-linker',
803+
'--gen-opt-lto', '-gen-opt-lto',
804+
'--generate-line-info', '-lineinfo',
805+
'--host-relocatable-link', '-r',
806+
'--keep', '-keep',
807+
'--keep-device-functions', '-keep-device-functions',
808+
'--lib', '-lib',
809+
'--link', '-link',
810+
'--list-gpu-arch', '-arch-ls',
811+
'--list-gpu-code', '-code-ls',
812+
'--lto', '-lto',
813+
'--no-align-double', '--no-align-double',
814+
'--no-compress', '-no-compress',
815+
'--no-device-link', '-nodlink',
816+
'--no-display-error-number', '-no-err-no',
817+
'--no-exceptions', '-noeh',
818+
'--no-host-device-initializer-list', '-nohdinitlist',
819+
'--no-host-device-move-forward', '-nohdmoveforward',
820+
'--objdir-as-tempdir', '-objtemp',
821+
'--optix-ir', '-optix-ir',
822+
'--ptx', '-ptx',
823+
'--qpp-config', '-qpp-config',
824+
'--resource-usage', '-res-usage',
825+
'--restrict', '-restrict',
826+
'--run', '-run',
827+
'--source-in-ptx', '-src-in-ptx',
828+
'--use-local-env', '-use-local-env',
829+
'--use_fast_math', '-use_fast_math',
830+
}
831+
_nvcc_flags_with_arg = {
780832
# long name, short name
781-
'--relocatable-device-code', '-rdc',
833+
'--archive-options', '-Xarchive',
834+
'--archiver-binary', '-arbin',
835+
'--brief-diagnostics', '-brief-diag',
782836
'--compiler-bindir', '-ccbin',
783-
'--compiler-options', '-Xcompiler')
837+
'--compiler-options', '-Xcompiler',
838+
'--cudadevrt', '-cudadevrt',
839+
'--cudart', '-cudart',
840+
'--default-stream', '-default-stream',
841+
'--dependency-drive-prefix', '-ddp',
842+
'--diag-error', '-diag-error',
843+
'--diag-suppress', '-diag-suppress',
844+
'--diag-warn', '-diag-warn',
845+
'--dopt', '-dopt',
846+
'--drive-prefix', '-dp',
847+
'--entries', '-e',
848+
'--fmad', '-fmad',
849+
'--ftemplate-backtrace-limit', '-ftemplate-backtrace-limit',
850+
'--ftemplate-depth', '-ftemplate-depth',
851+
'--ftz', '-ftz',
852+
'--generate-code', '-gencode',
853+
'--gpu-code', '-code',
854+
'--host-linker-script', '-hls',
855+
'--input-drive-prefix', '-idp',
856+
'--keep-dir', '-keep-dir',
857+
'--libdevice-directory', '-ldir',
858+
'--machine', '-m',
859+
'--maxrregcount', '-maxrregcount',
860+
'--nvlink-options', '-Xnvlink',
861+
'--optimization-info', '-opt-info',
862+
'--options-file', '-optf',
863+
'--output-directory', '-odir',
864+
'--prec-div', '-prec-div',
865+
'--prec-sqrt', '-prec-sqrt',
866+
'--ptxas-options', '-Xptxas',
867+
'--relocatable-device-code', '-rdc',
868+
'--run-args', '-run-args',
869+
'--split-compile', '-split-compile',
870+
'--target-directory', '-target-dir',
871+
'--threads', '-t',
872+
'--version-ident', '-dQ',
873+
}
784874
_nvcc_rewrite_flags = {
785-
# NVCC flag: equiavelent clang flag
786-
"--output-file": "-o",
787-
"--std": "-std",
788-
"--x": "-x"}
875+
# NVCC flag: clang flag
876+
'--Werror': '-Werror',
877+
'--Wno-deprecated-declarations': '-Wno-deprecated-declarations',
878+
'--Wreorder': '-Wreorder',
879+
'--compile': '-c',
880+
'--debug': '-g',
881+
'--define-macro': '-D',
882+
'--dependency-output': '-MF',
883+
'--dependency-target-name': '-MT',
884+
'--device-debug': '-G',
885+
'--disable-warnings': '-w',
886+
'--generate-dependencies': '-M',
887+
'--generate-dependencies-with-compile': '-MD',
888+
'--generate-dependency-targets': '-MP',
889+
'--generate-nonsystem-dependencies': '-MM',
890+
'--generate-nonsystem-dependencies-with-compile': '-MMD',
891+
'--gpu-architecture': '-arch',
892+
'--include-path': '-I',
893+
'--library': '-l',
894+
'--library-path': '-L',
895+
'--linker-options': '-Xlinker',
896+
'--m64': '-m64',
897+
'--optimize': '-O',
898+
'--output-file': '-o',
899+
'--pre-include': '-include',
900+
'--preprocess': '-E',
901+
'--profile': '-pg',
902+
'--save-temps': '-save-temps',
903+
'--std': '-std',
904+
'--system-include': '-isystem',
905+
'--time': '-time',
906+
'--undefine-macro': '-U',
907+
'--verbose': '-v',
908+
'--x': '-x',
909+
'-V': '--version',
910+
'-h': '--help',
911+
}
789912

790913
def _nvcc_patch(compile_args: typing.List[str]) -> typing.List[str]:
791914
"""Apply fixes to args to nvcc.
@@ -838,11 +961,13 @@ def _get_cpp_command_for_files(compile_action):
838961
# Patch command by platform
839962
compile_action.arguments = _all_platform_patch(compile_action.arguments)
840963
compile_action.arguments = _apple_platform_patch(compile_action.arguments)
841-
compile_action.arguments = _nvcc_patch(compile_action.arguments)
842964
# Android and Linux and grailbio LLVM toolchains: Fine as is; no special patching needed.
843965

844966
source_files, header_files = _get_files(compile_action)
845967

968+
# Done after getting files since we may execute NVCC to get the files.
969+
compile_action.arguments = _nvcc_patch(compile_action.arguments)
970+
846971
return source_files, header_files, compile_action.arguments
847972

848973

0 commit comments

Comments
 (0)