Skip to content

Commit 05fd4d5

Browse files
authored
[flang][cuda] Perform inlined assignment when field is c_devptr (llvm#124322)
When a field in a derived type is `c_devptr`, keep checking whether we can do a memcpy instead of falling back to the runtime assignment. Many internal CUDA Fortran derived types have a `c_devptr` field, and this would lead to a stack overflow on the device if the assignment were performed by the runtime function.
1 parent 4b209c5 commit 05fd4d5

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1410,7 +1410,8 @@ static bool recordTypeCanBeMemCopied(fir::RecordType recordType) {
14101410
for (auto [_, fieldType] : recordType.getTypeList()) {
14111411
// Derived type component may have user assignment (so far, we cannot tell
14121412
// in FIR, so assume it is always the case, TODO: get the actual info).
1413-
if (mlir::isa<fir::RecordType>(fir::unwrapSequenceType(fieldType)))
1413+
if (mlir::isa<fir::RecordType>(fir::unwrapSequenceType(fieldType)) &&
1414+
!fir::isa_builtin_c_devptr_type(fir::unwrapSequenceType(fieldType)))
14141415
return false;
14151416
// Allocatable components need deep copy.
14161417
if (auto boxType = mlir::dyn_cast<fir::BaseBoxType>(fieldType))

flang/test/Lower/CUDA/cuda-devptr.cuf

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44

55
module cudafct
66
use __fortran_builtins, only : c_devptr => __builtin_c_devptr
7+
8+
type :: t1
9+
type(c_devptr) :: devp
10+
integer :: a
11+
end type
12+
713
contains
814
function c_devloc(x)
915
use iso_c_binding, only: c_loc
@@ -12,6 +18,10 @@ contains
1218
real, target, device :: x
1319
c_devloc%cptr = c_loc(x)
1420
end function
21+
22+
attributes(device) function get_t1()
23+
type(t1) :: get_t1
24+
end
1525
end
1626

1727
subroutine sub1()
@@ -68,3 +78,12 @@ end subroutine
6878
! CHECK: %[[P_ADDR_COORD:.*]] = fir.coordinate_of %[[P_CPTR_COORD]], %[[ADDRESS_FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
6979
! CHECK: %[[ADDR:.*]] = fir.load %[[RES_ADDR_COORD]] : !fir.ref<i64>
7080
! CHECK: fir.store %[[ADDR]] to %[[P_ADDR_COORD]] : !fir.ref<i64>
81+
82+
attributes(global) subroutine assign_nested_c_devptr(p, a)
83+
use cudafct
84+
type(t1), device :: p
85+
p = get_t1()
86+
end subroutine
87+
88+
! CHECK-LABEL: func.func @_QPassign_nested_c_devptr
89+
! CHECK-NOT: fir.call @_FortranAAssign

0 commit comments

Comments
 (0)