Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Fix kernel registration #113372

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

clementval
Copy link
Contributor

The registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well.

@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 22, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

The registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well.


Full diff: https://github.com/llvm/llvm-project/pull/113372.diff

3 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/registration.h (+2-1)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp (+8-6)
  • (modified) flang/runtime/CUDA/registration.cpp (+4-3)
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
 void *RTDECL(CUFRegisterModule)(void *data);
 
 /// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName);
 
 } // extern "C"
 
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
   llvm::Type *ptrTy = builder.getPtrTy(0);
   llvm::FunctionCallee fct = module->getOrInsertFunction(
       RTNAME_STRING(CUFRegisterFunction),
-      llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
-                              false));
+      llvm::FunctionType::get(
+          ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
   llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
-  builder.CreateCall(
-      fct, {modulePtr, getOrCreateFunctionName(module, builder,
-                                               op.getKernelModuleName().str(),
-                                               op.getKernelName().str())});
+  llvm::Function *fctSym =
+      moduleTranslation.lookupFunction(op.getKernelName().str());
+  builder.CreateCall(fct, {modulePtr, fctSym,
+                           getOrCreateFunctionName(
+                               module, builder, op.getKernelModuleName().str(),
+                               op.getKernelName().str())});
   return mlir::success();
 }
 
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
   return fatHandle;
 }
 
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
-  __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
-      (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName) {
+  __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+      (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
 }
 }
 } // namespace Fortran::runtime::cuda

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

The registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well.


Full diff: https://github.com/llvm/llvm-project/pull/113372.diff

3 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/registration.h (+2-1)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp (+8-6)
  • (modified) flang/runtime/CUDA/registration.cpp (+4-3)
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
 void *RTDECL(CUFRegisterModule)(void *data);
 
 /// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName);
 
 } // extern "C"
 
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
   llvm::Type *ptrTy = builder.getPtrTy(0);
   llvm::FunctionCallee fct = module->getOrInsertFunction(
       RTNAME_STRING(CUFRegisterFunction),
-      llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
-                              false));
+      llvm::FunctionType::get(
+          ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
   llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
-  builder.CreateCall(
-      fct, {modulePtr, getOrCreateFunctionName(module, builder,
-                                               op.getKernelModuleName().str(),
-                                               op.getKernelName().str())});
+  llvm::Function *fctSym =
+      moduleTranslation.lookupFunction(op.getKernelName().str());
+  builder.CreateCall(fct, {modulePtr, fctSym,
+                           getOrCreateFunctionName(
+                               module, builder, op.getKernelModuleName().str(),
+                               op.getKernelName().str())});
   return mlir::success();
 }
 
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
   return fatHandle;
 }
 
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
-  __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
-      (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName) {
+  __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+      (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
 }
 }
 } // namespace Fortran::runtime::cuda

Copy link
Contributor

@Renaud-K Renaud-K left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

@clementval clementval merged commit 60105ac into llvm:main Oct 23, 2024
12 checks passed
@clementval clementval deleted the cuf_fix_fct_registration branch October 23, 2024 18:26
@frobtech frobtech mentioned this pull request Oct 25, 2024
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
The registration needs the fct pointer and the name. This patch updates
the entry point with an extra arg and the translation as well.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants