Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: intwanghao <[email protected]>
  • Loading branch information
intwanghao committed Dec 18, 2024
1 parent 2b73dcb commit 45c88ea
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 21 deletions.
2 changes: 2 additions & 0 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5215,6 +5215,8 @@ void DeviceFunctionDecl::insertWrapper() {
<< "nd_range<3> &nr, unsigned int localMemSize, void "
"**kernelParams, void **extra)";
} else {
Printer.line("// Auto generated SYCL kernel wrapper used to migration "
"kernel function pointer.");
if (!TParamsInfo.empty()) {
Printer << "template<";
for (size_t i = 0; i < TParamsInfo.size(); i++) {
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4513,7 +4513,7 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
}
emplaceTransformation(new InsertBeforeStmt(
E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "("));
emplaceTransformation(new InsertAfterStmt(E, ")"));
emplaceTransformation(new InsertAfterStmt(E, ").get()"));
}

void KernelCallRefRule::runRule(
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,14 @@ class KernelCallRule : public NamedMigrationRule<KernelCallRule> {
SourceLocation &EpilogLocation);
};

/// Migration rule for kernel function references.
/// This rule handles kernel functions that are used as function pointers.
/// For such kernel functions, a wrapper is generated with a `_wrapper`
/// postfix added to the kernel function name. Additionally, if the kernel
/// function pointer is used in a context where its original type information is
/// erased (e.g., raw pointer usage), an extra wrapper registration is required.
/// This ensures that the raw pointer is associated with the appropriate wrapper
/// and retains the necessary type information.
class KernelCallRefRule : public NamedMigrationRule<KernelCallRefRule> {
std::string getTypeRepl(const Expr *E);
template <typename T>
Expand Down
95 changes: 84 additions & 11 deletions clang/runtime/dpct-rt/include/dpct/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,30 @@ static inline void invoke_kernel_function(dpct::kernel_function &function,
localMemSize, kernelParams, extra);
}

/// Utility class for launching SYCL kernels through an auto-generated kernel
/// function wrapper.
/// For example:
/// A SYCL kernel function and auto generated wrapper:
/// void kernel_func(int *ptr, sycl::nd_item<3> item);
/// void kernel_func_wrapper(int *ptr) {
/// sycl::queue queue = *dpct::kernel_launch::_que;
/// unsigned int localMemSize = dpct::kernel_launch::_local_mem_size;
/// sycl::nd_range<3> nr = dpct::kernel_launch::_nr;
/// queue.parallel_for(
/// nr,
/// [=](sycl::nd_item<3> item_ct1) {
/// kernel_func(ptr, item_ct1);
/// });
/// }
/// Then launch the kernel through auto generated wrapper like:
/// typedef void(*fpt)(int *);
/// fpt fp = kernel_func_wrapper;
/// dpct::kernel_launch::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
/// device_ptr);
/// If the original function type is erased (e.g. the pointer is stored as
/// `void *`), the wrapper must be registered first:
/// void *fp = (void *)dpct::wrapper_register(&kernel_func_wrapper).get();
/// dpct::kernel_launch::launch(fp, dpct::dim3(1), dpct::dim3(1), args, 0,
/// 0);
class kernel_launch {
template <typename FuncT, typename ArgSelector, std::size_t... Index>
static void launch_helper(FuncT &&func, ArgSelector &selector,
Expand All @@ -464,62 +488,111 @@ class kernel_launch {
};

public:
/// Variables for storing execution configuration.
static inline thread_local sycl::queue *_que = nullptr;
static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
static inline thread_local unsigned int _local_mem_size = 0;
/// Map for retrieving launchable functor from a raw pointer.
static inline std::map<
const void *,
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
wrapper_map = {};
kernel_function_ptr_map = {};

static void regifter_kernel_launcher(
/// Registers a kernel function pointer with a corresponding launchable
/// functor.
/// \param [in] func Pointer to the kernel function.
/// \param [in] launcher Functor to handle kernel invocation.
static void regifter_kernel_ptr(
const void *func,
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
launcher) {
wrapper_map[func] = std::move(launcher);
kernel_function_ptr_map[func] = std::move(launcher);
}

/// Launches a kernel function with arguments provided directly through
/// auto generated kernel function wrapper.
/// \tparam FuncT Type of the auto generated kernel function wrapper.
/// \tparam ArgsT Types of kernel arguments.
/// \param [in] func Pointer to the auto generated kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] local_mem_size The size of local memory required by the kernel
/// function.
/// \param [in] que SYCL queue used to execute kernel.
/// \param [in] args Kernel arguments.
template <typename FuncT, typename... ArgsT>
static void launch(FuncT *func, dim3 group_range, dim3 local_range,
unsigned int local_mem_size, queue_ptr que,
ArgsT... args) {
set_execution_config(group_range, local_range, local_mem_size, que);
func(args...);
}

/// Launches a kernel function through registered auto generated kernel
/// function wrapper.
/// \param [in] func Pointer to the registered auto generated kernel
/// function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] args Array of pointers to kernel arguments.
/// \param [in] local_mem_size The size of local memory required by the kernel
/// function.
/// \param [in] que SYCL queue used to execute kernel.
static void launch(const void *func, dim3 group_range, dim3 local_range,
void **args, unsigned int local_mem_size, queue_ptr que) {
wrapper_map[func](group_range, local_range, args, local_mem_size, que);
kernel_function_ptr_map[func](group_range, local_range, args,
local_mem_size, que);
}

/// Launches a kernel function with packed arguments through auto generated
/// kernel function wrapper.
/// \tparam FuncT Type of the auto generated kernel function wrapper.
/// \param [in] func Pointer to the auto generated kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] args Array of pointers to kernel arguments.
/// \param [in] local_mem_size The size of local memory required by the kernel
/// function.
/// \param [in] que SYCL queue used to execute kernel.
template <typename FuncT>
static typename std::enable_if<std::is_function<FuncT>::value, void>::type
launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
unsigned int local_mem_size, queue_ptr que) {
static void launch(FuncT *func, dim3 group_range, dim3 local_range,
void **args, unsigned int local_mem_size, queue_ptr que) {
constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
set_execution_config(group_range, local_range, local_mem_size, que);
args_selector<p_num, p_num, FuncT> selector(args, nullptr);
launch_helper(func, selector, std::make_index_sequence<p_num>{});
}
};

/// Helper class to register and invoke kernel functions through a wrapper.
template <typename F> class wrapper_register;
template <typename Ret, typename... Args>
class wrapper_register<Ret (*)(Args...)> {
public:
typedef Ret (*FT)(Args...);
FT func;
/// Constructor to register a kernel function pointer.
/// \param [in] fp Pointer to the kernel function.
wrapper_register(FT fp) : func(fp) {
kernel_launch::regifter_kernel_launcher((void *)func, *this);
kernel_launch::regifter_kernel_ptr((void *)func, *this);
}
/// Invokes the kernel function through the stored kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] args Array of pointers to kernel arguments.
/// \param [in] local_mem_size The size of local memory required by the kernel
/// function.
/// \param [in] que SYCL queue used to execute kernel.
void operator()(dim3 group_range, dim3 local_range, void **args,
unsigned int local_mem_size, queue_ptr que) {
kernel_launch::launch(func, group_range, local_range, args, local_mem_size,
que);
}
/// Retrieves the original kernel function pointer.
/// \return The original kernel function pointer.
const FT &get() const noexcept { return func; }
/// Implicit conversion to the original kernel function pointer.
/// \return The original kernel function pointer.
operator FT() const noexcept { return func; }
};
/// Deduction guide for wrapper_register.
template <typename Ret, typename... Args>
wrapper_register(Ret (*)(Args...)) -> wrapper_register<Ret (*)(Args...)>;

Expand Down
11 changes: 11 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,17 @@ class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
}
};

/// \brief This struct template is used to get the type of the N-th argument of
/// a callable type `Func`. It supports both function types (e.g., `void(int,
/// double)`) and callable objects such as lambdas or functors.
///
/// \tparam Func The callable type from which to extract the argument type.
/// \tparam N The index of the argument to retrieve.
///
/// Example:
/// using Func = void(int, double, const char*);
/// static_assert(std::is_same<nth_argument_type<Func, 0>::type, int>::value,
/// "Unexpected type");
template <typename Func, std::size_t N> struct nth_argument_type {
template <typename R, typename... Args>
static auto
Expand Down
7 changes: 4 additions & 3 deletions clang/test/dpct/function_pointer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void foo() {
cudaMemcpy(d_A, h_A, size, cudaMemcpyHostToDevice);
cudaMemcpy(d_B, h_B, size, cudaMemcpyHostToDevice);

// CHECK: fpt<int> fp = dpct::wrapper_register(vectorAdd_wrapper);
// CHECK: fpt<int> fp = dpct::wrapper_register(vectorAdd_wrapper).get();
// CHECK: dpct::kernel_launch::launch(fp, 1, 10, 0, 0, d_A, d_B, d_C, N);
fpt<int> fp = vectorAdd;
fp<<<1, 10>>>(d_A, d_B, d_C, N);
Expand Down Expand Up @@ -172,9 +172,10 @@ void goo(fpt<T> p) {

template <typename T>
void hoo() {
// CHECK: fpt<int> a = dpct::wrapper_register<decltype(a)>(vectorTemplateAdd_wrapper);
// CHECK: fpt<int> a = dpct::wrapper_register<decltype(a)>(vectorTemplateAdd_wrapper).get();
fpt<int> a = vectorTemplateAdd;
// CHECK: goo<T>(dpct::wrapper_register<typename dpct::nth_argument_type<decltype(goo<T>), 0>::type>(vectorTemplateAdd_wrapper));
goo<int>(a);
// CHECK: goo<T>(dpct::wrapper_register<typename dpct::nth_argument_type<decltype(goo<T>), 0>::type>(vectorTemplateAdd_wrapper).get());
goo<T>(vectorTemplateAdd);
}

Expand Down
4 changes: 2 additions & 2 deletions clang/test/dpct/launch-kernel-cooperative-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ __global__ void kernel(int *d, cudaTextureObject_t tex) {
int gtid = blockIdx.x * blockDim.x + threadIdx.x;
tex1D(d + gtid, tex, gtid);
}

// CHECK: // Auto generated SYCL kernel wrapper used to migration kernel function pointer.
// CHECK: void kernel_wrapper(int * d ,dpct::image_wrapper_base_p tex) {
// CHECK: sycl::queue queue = *dpct::kernel_launch::_que;
// CHECK: unsigned int localMemSize = dpct::kernel_launch::_local_mem_size;
Expand Down Expand Up @@ -97,7 +97,7 @@ int main() {
// CHECK-NEXT: });
cudaLaunchCooperativeKernel((const void *)&template_kernel<int>, dim3(16), dim3(16), args, 32, stream);

// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get();
void *kernel_func = (void *)&kernel;

// CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0);
Expand Down
2 changes: 1 addition & 1 deletion clang/test/dpct/launch-kernel-cooperative.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ int main() {
// CHECK-NEXT: });
// CHECK-NEXT: });
cudaLaunchCooperativeKernel((const void *)&template_kernel<int>, dim3(16), dim3(16), args, 32, stream);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get();
void *kernel_func = (void *)&kernel;

// CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0);
Expand Down
4 changes: 2 additions & 2 deletions clang/test/dpct/launch-kernel-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ int main() {
// CHECK-NEXT: });
// CHECK-NEXT: });
cudaLaunchKernel((const void *)&template_kernel<int>, dim3(16), dim3(16), args, 32, stream);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get();
void *kernel_func = (void *)&kernel;
// CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0);
cudaLaunchKernel(kernel_func, dim3(16), dim3(16), args, 0, 0);

void *kernel_array[100];
// CHECK: kernel_array[10] = (void *)dpct::wrapper_register(&kernel_wrapper);
// CHECK: kernel_array[10] = (void *)dpct::wrapper_register(&kernel_wrapper).get();
kernel_array[10] = (void *)&kernel;
// CHECK: dpct::kernel_launch::launch(kernel_array[10], dpct::dim3(16), dpct::dim3(16), args, 0, 0);
cudaLaunchKernel(kernel_array[10], dim3(16), dim3(16), args, 0, 0);
Expand Down
2 changes: 1 addition & 1 deletion clang/test/dpct/launch-kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ int main() {
// CHECK-NEXT: });
// CHECK-NEXT: });
cudaLaunchKernel((const void *)&template_kernel<int>, dim3(16), dim3(16), args, 32, stream);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper);
// CHECK: void *kernel_func = (void *)dpct::wrapper_register(&kernel_wrapper).get();
void *kernel_func = (void *)&kernel;
// CHECK: dpct::kernel_launch::launch(kernel_func, dpct::dim3(16), dpct::dim3(16), args, 0, 0);
cudaLaunchKernel(kernel_func, dim3(16), dim3(16), args, 0, 0);
Expand Down

0 comments on commit 45c88ea

Please sign in to comment.