diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h index cbe202c4d23e0..009715613e29f 100644 --- a/flang/include/flang/Runtime/CUDA/registration.h +++ b/flang/include/flang/Runtime/CUDA/registration.h @@ -20,7 +20,8 @@ extern "C" { void *RTDECL(CUFRegisterModule)(void *data); /// Register a device function. -void RTDECL(CUFRegisterFunction)(void **module, const char *fct); +void RTDECL(CUFRegisterFunction)( + void **module, const char *fctSym, char *fctName); } // extern "C" diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp index c6c9f96b81135..63eac46a99771 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp @@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op, llvm::Type *ptrTy = builder.getPtrTy(0); llvm::FunctionCallee fct = module->getOrInsertFunction( RTNAME_STRING(CUFRegisterFunction), - llvm::FunctionType::get(ptrTy, ArrayRef({ptrTy, ptrTy}), - false)); + llvm::FunctionType::get( + ptrTy, ArrayRef({ptrTy, ptrTy, ptrTy}), false)); llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr()); - builder.CreateCall( - fct, {modulePtr, getOrCreateFunctionName(module, builder, - op.getKernelModuleName().str(), - op.getKernelName().str())}); + llvm::Function *fctSym = + moduleTranslation.lookupFunction(op.getKernelName().str()); + builder.CreateCall(fct, {modulePtr, fctSym, + getOrCreateFunctionName( + module, builder, op.getKernelModuleName().str(), + op.getKernelName().str())}); return mlir::success(); } diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp index e5d9503e95fd8..22d43a7dc57a3 100644 --- a/flang/runtime/CUDA/registration.cpp +++ b/flang/runtime/CUDA/registration.cpp @@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) { return fatHandle; } -void RTDEF(CUFRegisterFunction)(void **module, const char *fct) { - __cudaRegisterFunction(module, fct, const_cast(fct), fct, -1, - (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0); +void RTDEF(CUFRegisterFunction)( + void **module, const char *fctSym, char *fctName) { + __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0, + (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0); } } } // namespace Fortran::runtime::cuda