-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics #120956
[MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics #120956
Conversation
@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis PR updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and WgmmaWaitGroupSyncOp Ops in the NVVM Dialect to lower to the corresponding intrinsics instead of inline-ptx. The existing test under Conversion/NVVMToLLVM is updated to check for the new patterns and separate tests are added under Target/LLVMIR to verify the lowered intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/120956.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 530135b912b9e6..a2d2102b59dece 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2139,12 +2139,12 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
}];
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let description = [{
@@ -2152,21 +2152,21 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.a
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
- let arguments = (ins I32Attr:$group);
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+ let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
Signal the completion of a preceding warpgroup operation.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
}];
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 66b736c18718f3..84ea55ceb5acc2 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -266,19 +266,17 @@ func.func @wgmma_execute() {
nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 0
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+ // CHECK: nvvm.wgmma.fence.aligned
+ // CHECK: nvvm.wgmma.commit.group.sync.aligned
+ // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 5
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+ // CHECK: nvvm.wgmma.fence.aligned
+ // CHECK: nvvm.wgmma.commit.group.sync.aligned
+ // CHECK: nvvm.wgmma.wait.group.sync.aligned 5
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 6a32190694b470..b69d77496351c1 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_fence_aligned
+llvm.func @nvvm_wgmma_fence_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
+ nvvm.wgmma.fence.aligned
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
+llvm.func @nvvm_wgmma_commit_group_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
+ nvvm.wgmma.commit.group.sync.aligned
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
+llvm.func @nvvm_wgmma_wait_group_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+ nvvm.wgmma.wait.group.sync.aligned 0
+ // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
+ nvvm.wgmma.wait.group.sync.aligned 20
+ llvm.return
+}
|
This patch updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and WgmmaWaitGroupSyncOp Ops in the NVVM Dialect to lower to the corresponding intrinsics instead of inline-ptx. The existing test under Conversion/NVVMToLLVM is updated to check for the new patterns and separate tests are added under Target/LLVMIR to verify the lowered intrinsics.
39f4342
to
21a06ab
Compare
Changes LGTM |
@grypp , Could you please help with a review? |
Merging this per offline request from @Wolfram70 |
This PR updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and WgmmaWaitGroupSyncOp Ops in the NVVM Dialect to lower to the corresponding intrinsics instead of inline-ptx.
The existing test under Conversion/NVVMToLLVM is updated to check for the new patterns and separate tests are added under Target/LLVMIR to verify the lowered intrinsics.