@@ -771,17 +771,21 @@ def _all_platform_patch(compile_args: typing.List[str]):
771
771
772
772
return compile_args
773
773
774
- _nvcc_flags_no_arg = (
774
+ _nvcc_flags_no_arg = {
775
775
# long name, short name
776
776
'--expt-relaxed-constexpr' , '-expt-relaxed-constexpr' ,
777
777
'--expt-extended-lambda' , '-expt-extended-lambda' ,
778
- '--extended-lambda' , '-extended-lambda' ,
779
- )
778
+ '--extended-lambda' , '-extended-lambda' }
780
779
_nvcc_flags_with_arg = (
781
780
# long name, short name
782
781
'--relocatable-device-code' , '-rdc' ,
783
782
'--compiler-bindir' , '-ccbin' ,
784
783
'--compiler-options' , '-Xcompiler' )
784
+ _nvcc_rewrite_flags = {
785
+ # NVCC flag: equiavelent clang flag
786
+ "--output-file" : "-o" ,
787
+ "--std" : "-std" ,
788
+ "--x" : "-x" }
785
789
786
790
def _nvcc_patch (compile_args : typing .List [str ]) -> typing .List [str ]:
787
791
"""Apply fixes to args to nvcc.
@@ -798,13 +802,17 @@ def _nvcc_patch(compile_args: typing.List[str]) -> typing.List[str]:
798
802
# Make clangd's behavior closer to nvcc's.
799
803
# I think this might become the default in clangd 17: https://reviews.llvm.org/D151359
800
804
'-Xclang' , '-fcuda-allow-variadic-functions' ]
801
- skip_next = True # skip the first arg which we added above
805
+ skip_next = True # skip the compile_args[0] which we added above
802
806
for arg in compile_args :
803
807
if skip_next :
804
808
skip_next = False
805
809
continue
806
810
if arg in _nvcc_flags_no_arg :
807
811
continue
812
+ rewrite_to = _nvcc_rewrite_flags .get (arg )
813
+ if rewrite_to :
814
+ new_compile_args .append (rewrite_to )
815
+ continue
808
816
skip = False
809
817
for flag_with_arg in _nvcc_flags_with_arg :
810
818
if arg == flag_with_arg :
0 commit comments