-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Fix kernel registration #113372
Conversation
@llvm/pr-subscribers-flang-runtime Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThe registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well. Full diff: https://github.com/llvm/llvm-project/pull/113372.diff 3 Files Affected:
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
void *RTDECL(CUFRegisterModule)(void *data);
/// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName);
} // extern "C"
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
- llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
- false));
+ llvm::FunctionType::get(
+ ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
- builder.CreateCall(
- fct, {modulePtr, getOrCreateFunctionName(module, builder,
- op.getKernelModuleName().str(),
- op.getKernelName().str())});
+ llvm::Function *fctSym =
+ moduleTranslation.lookupFunction(op.getKernelName().str());
+ builder.CreateCall(fct, {modulePtr, fctSym,
+ getOrCreateFunctionName(
+ module, builder, op.getKernelModuleName().str(),
+ op.getKernelName().str())});
return mlir::success();
}
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
return fatHandle;
}
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
- __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
- (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName) {
+ __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+ (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThe registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well. Full diff: https://github.com/llvm/llvm-project/pull/113372.diff 3 Files Affected:
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
void *RTDECL(CUFRegisterModule)(void *data);
/// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName);
} // extern "C"
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
- llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
- false));
+ llvm::FunctionType::get(
+ ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
- builder.CreateCall(
- fct, {modulePtr, getOrCreateFunctionName(module, builder,
- op.getKernelModuleName().str(),
- op.getKernelName().str())});
+ llvm::Function *fctSym =
+ moduleTranslation.lookupFunction(op.getKernelName().str());
+ builder.CreateCall(fct, {modulePtr, fctSym,
+ getOrCreateFunctionName(
+ module, builder, op.getKernelModuleName().str(),
+ op.getKernelName().str())});
return mlir::success();
}
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
return fatHandle;
}
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
- __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
- (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName) {
+ __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+ (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
The registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well.
The registration needs the fct pointer and the name. This patch updates the entry point with an extra arg and the translation as well.