From 45c88ea3e64f3924ed533c23598983f9753a8e3b Mon Sep 17 00:00:00 2001 From: intwanghao Date: Wed, 18 Dec 2024 14:53:05 +0800 Subject: [PATCH] fix Signed-off-by: intwanghao --- clang/lib/DPCT/AnalysisInfo.cpp | 2 + clang/lib/DPCT/RulesLang/RulesLang.cpp | 2 +- clang/lib/DPCT/RulesLang/RulesLang.h | 8 ++ clang/runtime/dpct-rt/include/dpct/kernel.hpp | 95 ++++++++++++++++--- clang/runtime/dpct-rt/include/dpct/util.hpp | 11 +++ clang/test/dpct/function_pointer.cu | 7 +- .../dpct/launch-kernel-cooperative-usm.cu | 4 +- clang/test/dpct/launch-kernel-cooperative.cu | 2 +- clang/test/dpct/launch-kernel-usm.cu | 4 +- clang/test/dpct/launch-kernel.cu | 2 +- 10 files changed, 116 insertions(+), 21 deletions(-) diff --git a/clang/lib/DPCT/AnalysisInfo.cpp b/clang/lib/DPCT/AnalysisInfo.cpp index 20f15217f342..628f5e816beb 100644 --- a/clang/lib/DPCT/AnalysisInfo.cpp +++ b/clang/lib/DPCT/AnalysisInfo.cpp @@ -5215,6 +5215,8 @@ void DeviceFunctionDecl::insertWrapper() { << "nd_range<3> &nr, unsigned int localMemSize, void " "**kernelParams, void **extra)"; } else { + Printer.line("// Auto generated SYCL kernel wrapper used to migration " + "kernel function pointer."); if (!TParamsInfo.empty()) { Printer << "template<"; for (size_t i = 0; i < TParamsInfo.size(); i++) { diff --git a/clang/lib/DPCT/RulesLang/RulesLang.cpp b/clang/lib/DPCT/RulesLang/RulesLang.cpp index bf4d38cd1d1e..b2a074212069 100644 --- a/clang/lib/DPCT/RulesLang/RulesLang.cpp +++ b/clang/lib/DPCT/RulesLang/RulesLang.cpp @@ -4513,7 +4513,7 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node, } emplaceTransformation(new InsertBeforeStmt( E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "(")); - emplaceTransformation(new InsertAfterStmt(E, ")")); + emplaceTransformation(new InsertAfterStmt(E, ").get()")); } void KernelCallRefRule::runRule( diff --git a/clang/lib/DPCT/RulesLang/RulesLang.h b/clang/lib/DPCT/RulesLang/RulesLang.h index 381a72cfc71a..afda6edd1bf3 100644 --- 
a/clang/lib/DPCT/RulesLang/RulesLang.h +++ b/clang/lib/DPCT/RulesLang/RulesLang.h @@ -449,6 +449,14 @@ class KernelCallRule : public NamedMigrationRule { SourceLocation &EpilogLocation); }; +/// Migration rule for kernel function references. +/// This rule handles kernel functions that are used as function pointers. +/// For such kernel functions, a wrapper is generated with a `_wrapper` +/// postfix added to the kernel function name. Additionally, if the kernel +/// function pointer is used in a context where its original type information is +/// erased (e.g., raw pointer usage), an extra wrapper registration is required. +/// This ensures that the raw pointer is associated with the appropriate wrapper +/// and retains the necessary type information. class KernelCallRefRule : public NamedMigrationRule { std::string getTypeRepl(const Expr *E); template diff --git a/clang/runtime/dpct-rt/include/dpct/kernel.hpp b/clang/runtime/dpct-rt/include/dpct/kernel.hpp index 540e68553e72..f5ca266168f6 100644 --- a/clang/runtime/dpct-rt/include/dpct/kernel.hpp +++ b/clang/runtime/dpct-rt/include/dpct/kernel.hpp @@ -444,6 +444,30 @@ static inline void invoke_kernel_function(dpct::kernel_function &function, localMemSize, kernelParams, extra); } +/// Utility class for launching SYCL kernels through auto generated kernel +/// function wrapper. 
+/// For example: +/// A SYCL kernel function and auto generated wrapper: +/// void kernel_func(int *ptr, sycl::nd_item<3> item); +/// void kernel_func_wrapper(int *ptr) { +/// sycl::queue queue = *dpct::kernel_launch::_que; +/// unsigned int localMemSize = dpct::kernel_launch::_local_mem_size; +/// sycl::nd_range<3> nr = dpct::kernel_launch::_nr; +/// queue.parallel_for( +/// nr, +/// [=](sycl::nd_item<3> item_ct1) { +/// kernel_func(ptr, item_ct1); +/// }); +/// } +/// Then launch the kernel through auto generated wrapper like: +/// typedef void(*fpt)(int *); +/// fpt fp = kernel_func_wrapper; +/// dpct::kernel_launch::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0, +/// device_ptr); +/// If the original function type is erased, the wrapper needs to be registered first: +/// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get(); +/// dpct::kernel_launch::launch(fp, dpct::dim3(1), dpct::dim3(1), args, 0, +/// 0); class kernel_launch { template static void launch_helper(FuncT &&func, ArgSelector &selector, @@ -464,21 +488,37 @@ class kernel_launch { }; public: + /// Variables for storing execution configuration. static inline thread_local sycl::queue *_que = nullptr; static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>(); static inline thread_local unsigned int _local_mem_size = 0; + /// Map for retrieving launchable functor from a raw pointer. static inline std::map< const void *, std::function> - wrapper_map = {}; + kernel_function_ptr_map = {}; - static void regifter_kernel_launcher( + /// Registers a kernel function pointer with a corresponding launchable + /// functor. + /// \param [in] func Pointer to the kernel function. + /// \param [in] launcher Functor to handle kernel invocation. 
+ static void regifter_kernel_ptr( const void *func, std::function launcher) { - wrapper_map[func] = std::move(launcher); + kernel_function_ptr_map[func] = std::move(launcher); } - + /// Launches a kernel function with arguments provided directly through + /// auto generated kernel function wrapper. + /// \tparam FuncT Type of the auto generated kernel function wrapper. + /// \tparam ArgsT Types of kernel arguments. + /// \param [in] func Pointer to the auto generated kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] local_mem_size The size of local memory required by the kernel + /// function. + /// \param [in] que SYCL queue used to execute kernel. + /// \param [in] args Kernel arguments. template static void launch(FuncT *func, dim3 group_range, dim3 local_range, unsigned int local_mem_size, queue_ptr que, @@ -486,16 +526,34 @@ class kernel_launch { set_execution_config(group_range, local_range, local_mem_size, que); func(args...); } - + /// Launches a kernel function through registered auto generated kernel + /// function wrapper. + /// \param [in] func Pointer to the registered auto generated kernel + /// function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] args Array of pointers to kernel arguments. + /// \param [in] local_mem_size The size of local memory required by the kernel + /// function. + /// \param [in] que SYCL queue used to execute kernel. static void launch(const void *func, dim3 group_range, dim3 local_range, void **args, unsigned int local_mem_size, queue_ptr que) { - wrapper_map[func](group_range, local_range, args, local_mem_size, que); + kernel_function_ptr_map[func](group_range, local_range, args, + local_mem_size, que); } - + /// Launches a kernel function with packed arguments through auto generated + /// kernel function wrapper. 
+ /// \tparam FuncT Type of the auto generated kernel function wrapper. + /// \param [in] func Pointer to the auto generated kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] args Array of pointers to kernel arguments. + /// \param [in] local_mem_size The size of local memory required by the kernel + /// function. + /// \param [in] que SYCL queue used to execute kernel. template - static typename std::enable_if::value, void>::type - launch(FuncT *func, dim3 group_range, dim3 local_range, void **args, - unsigned int local_mem_size, queue_ptr que) { + static void launch(FuncT *func, dim3 group_range, dim3 local_range, + void **args, unsigned int local_mem_size, queue_ptr que) { constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num; set_execution_config(group_range, local_range, local_mem_size, que); args_selector selector(args, nullptr); @@ -503,23 +561,38 @@ class kernel_launch { } }; +/// Helper class to register and invoke kernel functions through a wrapper. template class wrapper_register; template class wrapper_register { public: typedef Ret (*FT)(Args...); FT func; + /// Constructor to register a kernel function pointer. + /// \param [in] fp Pointer to the kernel function. wrapper_register(FT fp) : func(fp) { - kernel_launch::regifter_kernel_launcher((void *)func, *this); + kernel_launch::regifter_kernel_ptr((void *)func, *this); } + /// Invokes the kernel function through the stored kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] args Array of pointers to kernel arguments. + /// \param [in] local_mem_size The size of local memory required by the kernel + /// function. + /// \param [in] que SYCL queue used to execute kernel. 
void operator()(dim3 group_range, dim3 local_range, void **args, unsigned int local_mem_size, queue_ptr que) { kernel_launch::launch(func, group_range, local_range, args, local_mem_size, que); } + /// Retrieves the original kernel function pointer. + /// \return The original kernel function pointer. const FT &get() const noexcept { return func; } + /// Implicit conversion to the original kernel function pointer. + /// \return The original kernel function pointer. operator FT() const noexcept { return func; } }; +/// Deduction guide for wrapper_register. template wrapper_register(Ret (*)(Args...)) -> wrapper_register; diff --git a/clang/runtime/dpct-rt/include/dpct/util.hpp b/clang/runtime/dpct-rt/include/dpct/util.hpp index d30a3194d11f..91f40c479976 100644 --- a/clang/runtime/dpct-rt/include/dpct/util.hpp +++ b/clang/runtime/dpct-rt/include/dpct/util.hpp @@ -1187,6 +1187,17 @@ class args_selector { } }; +/// \brief This struct template is used to get the type of the N-th argument of a +/// callable type `Func`. It supports both function types (e.g., `void(int, +/// double)`) and callable objects such as lambdas or functors. +/// +/// \tparam Func The callable type from which to extract the argument type. +/// \tparam N The index of the argument to retrieve. 
+/// +/// Example: +/// using Func = void(int, double, const char*); +/// static_assert(std::is_same::type, int>::value, +/// "Unexpected type"); template struct nth_argument_type { template static auto diff --git a/clang/test/dpct/function_pointer.cu b/clang/test/dpct/function_pointer.cu index a51ea4eb2496..ba3b7dfbe9c2 100644 --- a/clang/test/dpct/function_pointer.cu +++ b/clang/test/dpct/function_pointer.cu @@ -67,7 +67,7 @@ void foo() { cudaMemcpy(d_A, h_A, size, cudaMemcpyHostToDevice); cudaMemcpy(d_B, h_B, size, cudaMemcpyHostToDevice); -// CHECK: fpt fp = dpct::wrapper_register(vectorAdd_wrapper); +// CHECK: fpt fp = dpct::wrapper_register(vectorAdd_wrapper).get(); // CHECK: dpct::kernel_launch::launch(fp, 1, 10, 0, 0, d_A, d_B, d_C, N); fpt fp = vectorAdd; fp<<<1, 10>>>(d_A, d_B, d_C, N); @@ -172,9 +172,10 @@ void goo(fpt p) { template void hoo() { - // CHECK: fpt a = dpct::wrapper_register(vectorTemplateAdd_wrapper); + // CHECK: fpt a = dpct::wrapper_register(vectorTemplateAdd_wrapper).get(); fpt a = vectorTemplateAdd; - // CHECK: goo(dpct::wrapper_register), 0>::type>(vectorTemplateAdd_wrapper)); + goo(a); + // CHECK: goo(dpct::wrapper_register), 0>::type>(vectorTemplateAdd_wrapper).get()); goo(vectorTemplateAdd); } diff --git a/clang/test/dpct/launch-kernel-cooperative-usm.cu b/clang/test/dpct/launch-kernel-cooperative-usm.cu index 249391c691e0..ef21fe2038cb 100644 --- a/clang/test/dpct/launch-kernel-cooperative-usm.cu +++ b/clang/test/dpct/launch-kernel-cooperative-usm.cu @@ -23,7 +23,7 @@ __global__ void kernel(int *d, cudaTextureObject_t tex) { int gtid = blockIdx.x * blockDim.x + threadIdx.x; tex1D(d + gtid, tex, gtid); } - +// CHECK: // Auto generated SYCL kernel wrapper used to migration kernel function pointer. 
// CHECK: void kernel_wrapper(int * d ,dpct::image_wrapper_base_p tex) { // CHECK: sycl::queue queue = *dpct::kernel_launch::_que; // CHECK: unsigned int localMemSize = dpct::kernel_launch::_local_mem_size; @@ -97,7 +97,7 @@ int main() { // CHECK-NEXT: }); cudaLaunchCooperativeKernel((const void *)&template_kernel, dim3(16), dim3(16), args, 32, stream); - // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper); + // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get(); void *kernel_func = (void *)&kernel; // CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0); diff --git a/clang/test/dpct/launch-kernel-cooperative.cu b/clang/test/dpct/launch-kernel-cooperative.cu index 3e197220e1b1..fb6f5d1018de 100644 --- a/clang/test/dpct/launch-kernel-cooperative.cu +++ b/clang/test/dpct/launch-kernel-cooperative.cu @@ -78,7 +78,7 @@ int main() { // CHECK-NEXT: }); // CHECK-NEXT: }); cudaLaunchCooperativeKernel((const void *)&template_kernel, dim3(16), dim3(16), args, 32, stream); - // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper); + // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get(); void *kernel_func = (void *)&kernel; // CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0); diff --git a/clang/test/dpct/launch-kernel-usm.cu b/clang/test/dpct/launch-kernel-usm.cu index 12ea3c073cd8..28ef55070ace 100644 --- a/clang/test/dpct/launch-kernel-usm.cu +++ b/clang/test/dpct/launch-kernel-usm.cu @@ -76,13 +76,13 @@ int main() { // CHECK-NEXT: }); // CHECK-NEXT: }); cudaLaunchKernel((const void *)&template_kernel, dim3(16), dim3(16), args, 32, stream); - // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper); + // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get(); void *kernel_func = (void *)&kernel; // CHECK: 
dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0); cudaLaunchKernel(kernel_func, dim3(16), dim3(16), args, 0, 0); void *kernel_array[100]; - // CHECK: kernel_array[10] = (void *)dpct::wrapper_register(&kernel_wrapper); + // CHECK: kernel_array[10] = (void *)dpct::wrapper_register(&kernel_wrapper).get(); kernel_array[10] = (void *)&kernel; // CHECK: dpct::kernel_launch::launch(kernel_array[10], dpct::dim3(16), dpct::dim3(16), args, 0, 0); cudaLaunchKernel(kernel_array[10], dim3(16), dim3(16), args, 0, 0); diff --git a/clang/test/dpct/launch-kernel.cu b/clang/test/dpct/launch-kernel.cu index 1a392ffb60c3..0e806f41f804 100644 --- a/clang/test/dpct/launch-kernel.cu +++ b/clang/test/dpct/launch-kernel.cu @@ -76,7 +76,7 @@ int main() { // CHECK-NEXT: }); // CHECK-NEXT: }); cudaLaunchKernel((const void *)&template_kernel, dim3(16), dim3(16), args, 32, stream); - // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper); + // CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get(); void *kernel_func = (void *)&kernel; // CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0); cudaLaunchKernel(kernel_func, dim3(16), dim3(16), args, 0, 0);