Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][mlir] Add support for translating task_reduction to LLVMIR #120957

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

NimishMishra
Copy link
Contributor

This PR adds support for translating task_reduction to LLVMIR for both pass-by-value and pass-by-reference. Depending on whether the reduction variables are pass-by-val or pass-by-ref, appropriate red_init and red_comb functions are emitted; and components of kmp_taskred_input_t are set. Finally, the runtime call __kmpc_taskred_init is emitted on the populated kmp_taskred_input_t.

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: None (NimishMishra)

Changes

This PR adds support for translating task_reduction to LLVMIR for both pass-by-value and pass-by-reference. Depending on whether the reduction variables are pass-by-val or pass-by-ref, appropriate red_init and red_comb functions are emitted; and components of kmp_taskred_input_t are set. Finally, the runtime call __kmpc_taskred_init is emitted on the populated kmp_taskred_input_t.


Full diff: https://github.com/llvm/llvm-project/pull/120957.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+250-10)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-28)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..4b9f85e8baf468 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+
+
+   auto getReductionSyms() { return getTaskReductionSyms(); }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..7126a0803189a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getThreadLimit())
       result = todo("thread_limit");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
-  };
   auto checkUntied = [&todo](auto op, LogicalResult &result) {
     if (op.getUntied())
       result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkInReduction(op, result);
         checkPriority(op, result);
       })
-      .Case([&](omp::TaskgroupOp op) {
-        checkAllocate(op, result);
-        checkTaskReduction(op, result);
-      })
+      .Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
         checkNowait(op, result);
@@ -1787,6 +1779,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  if (region.empty()) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType = nullptr;
+  if (isByRef[Cnt])
+    funcType = llvm::FunctionType::get(builder.getVoidTy(),
+                                       {OpaquePtrTy, OpaquePtrTy}, false);
+  else
+    funcType =
+        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    if (isByRef[Cnt]) {
+      // TODO: Handle case where the initializer uses initialization from
+      // declare reduction construct using `arg1Alloca`.
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *LoadVal =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(),
+                                 LoadVal);
+    } else {
+      mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                            reductionVariableMap, Cnt);
+    }
+  } else if (name == "red_comb") {
+    if (isByRef[Cnt]) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *arg0L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      llvm::Value *arg1L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    } else {
+      llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+      llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  if (!isByRef[Cnt]) {
+    bbBuilder.CreateStore(phis[0], arg0);
+    bbBuilder.CreateRet(arg0); // Return from the function
+  } else {
+    bbBuilder.CreateRet(nullptr);
+  }
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: omp.taskloop}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-llvm

Author: None (NimishMishra)

Changes

This PR adds support for translating task_reduction to LLVMIR for both pass-by-value and pass-by-reference. Depending on whether the reduction variables are pass-by-val or pass-by-ref, appropriate red_init and red_comb functions are emitted; and components of kmp_taskred_input_t are set. Finally, the runtime call __kmpc_taskred_init is emitted on the populated kmp_taskred_input_t.


Full diff: https://github.com/llvm/llvm-project/pull/120957.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+250-10)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-28)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..4b9f85e8baf468 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+
+
+   auto getReductionSyms() { return getTaskReductionSyms(); }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..7126a0803189a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getThreadLimit())
       result = todo("thread_limit");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
-  };
   auto checkUntied = [&todo](auto op, LogicalResult &result) {
     if (op.getUntied())
       result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkInReduction(op, result);
         checkPriority(op, result);
       })
-      .Case([&](omp::TaskgroupOp op) {
-        checkAllocate(op, result);
-        checkTaskReduction(op, result);
-      })
+      .Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
         checkNowait(op, result);
@@ -1787,6 +1779,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  if (region.empty()) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType = nullptr;
+  if (isByRef[Cnt])
+    funcType = llvm::FunctionType::get(builder.getVoidTy(),
+                                       {OpaquePtrTy, OpaquePtrTy}, false);
+  else
+    funcType =
+        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    if (isByRef[Cnt]) {
+      // TODO: Handle case where the initializer uses initialization from
+      // declare reduction construct using `arg1Alloca`.
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *LoadVal =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(),
+                                 LoadVal);
+    } else {
+      mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                            reductionVariableMap, Cnt);
+    }
+  } else if (name == "red_comb") {
+    if (isByRef[Cnt]) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *arg0L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      llvm::Value *arg1L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    } else {
+      llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+      llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  if (!isByRef[Cnt]) {
+    bbBuilder.CreateStore(phis[0], arg0);
+    bbBuilder.CreateRet(arg0); // Return from the function
+  } else {
+    bbBuilder.CreateRet(nullptr);
+  }
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: omp.taskloop}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-flang-openmp

Author: None (NimishMishra)

Changes

This PR adds support for translating task_reduction to LLVMIR for both pass-by-value and pass-by-reference. Depending on whether the reduction variables are pass-by-val or pass-by-ref, appropriate red_init and red_comb functions are emitted; and components of kmp_taskred_input_t are set. Finally, the runtime call __kmpc_taskred_init is emitted on the populated kmp_taskred_input_t.


Full diff: https://github.com/llvm/llvm-project/pull/120957.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+250-10)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-28)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..4b9f85e8baf468 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+
+
+   auto getReductionSyms() { return getTaskReductionSyms(); }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..7126a0803189a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getThreadLimit())
       result = todo("thread_limit");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
-  };
   auto checkUntied = [&todo](auto op, LogicalResult &result) {
     if (op.getUntied())
       result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkInReduction(op, result);
         checkPriority(op, result);
       })
-      .Case([&](omp::TaskgroupOp op) {
-        checkAllocate(op, result);
-        checkTaskReduction(op, result);
-      })
+      .Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
         checkNowait(op, result);
@@ -1787,6 +1779,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  if (region.empty()) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType = nullptr;
+  if (isByRef[Cnt])
+    funcType = llvm::FunctionType::get(builder.getVoidTy(),
+                                       {OpaquePtrTy, OpaquePtrTy}, false);
+  else
+    funcType =
+        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    if (isByRef[Cnt]) {
+      // TODO: Handle case where the initializer uses initialization from
+      // declare reduction construct using `arg1Alloca`.
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *LoadVal =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(),
+                                 LoadVal);
+    } else {
+      mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                            reductionVariableMap, Cnt);
+    }
+  } else if (name == "red_comb") {
+    if (isByRef[Cnt]) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *arg0L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      llvm::Value *arg1L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    } else {
+      llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+      llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  if (!isByRef[Cnt]) {
+    bbBuilder.CreateStore(phis[0], arg0);
+    bbBuilder.CreateRet(arg0); // Return from the function
+  } else {
+    bbBuilder.CreateRet(nullptr);
+  }
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: omp.taskloop}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants