From a3a64d2646e6f8ef95f514fdbdad7bc34f1d51c1 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 25 Oct 2023 10:16:57 -0700 Subject: [PATCH] Adding selection condition to hal.executable.variant. (#15284) This allows for variants to declare host logic that determines whether the variant should be selected for loading. When multiple variants are available their declared conditions will be evaluated in op order along with the existing executable format match. Unfortunately the MLIR SymbolTable trait disallows multiple regions on any op holding it so a new `hal.executable.condition` region op was added that may be optionally present on any `hal.executable.variant`. Ideally we clean this up and make it an optional region but that'll need relaxing of upstream assertions like https://sourcegraph.com/github.com/llvm/llvm-project/-/blob/mlir/lib/IR/SymbolTable.cpp?L122-123 (ideally either treating region 0 as the symbol table on ops or having an interface override for selecting the region ala `getCallableRegion` such as `getSymbolTableRegion`). This removes the `hal.device.switch` op in favor of `scf.index_switch`. When we start running const expr hoisting during the HAL pipeline this should allow variant selection to be completely hoisted to initialization time (or at least memoized per device). There's decent low-hanging future work on optimizing the ranking/selection and improving `scf.index_switch` hoisting/canonicalization to make things better. --- compiler/plugins/target/CUDA/CUDATarget.cpp | 7 +- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 1 - .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 - .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 2 +- .../Codegen/SPIRV/ConvertToSPIRVPass.cpp | 3 +- .../compiler/Codegen/Utils/LinkingUtils.cpp | 8 +- .../src/iree/compiler/Codegen/Utils/Utils.cpp | 4 +- .../test/outline_dispatch_regions.mlir | 18 +- .../Dialect/HAL/Conversion/BUILD.bazel | 1 - .../Dialect/HAL/Conversion/CMakeLists.txt | 1 - .../HAL/Conversion/HALToVM/BUILD.bazel | 1 - .../HAL/Conversion/HALToVM/CMakeLists.txt | 1 - .../HAL/Conversion/StreamToHAL/BUILD.bazel | 1 - .../HAL/Conversion/StreamToHAL/CMakeLists.txt | 1 - .../HAL/Conversion/StreamToHAL/Patterns.cpp | 66 ++-- .../Conversion/StreamToHAL/test/cmd_ops.mlir | 49 ++- .../iree/compiler/Dialect/HAL/IR/BUILD.bazel | 2 + .../compiler/Dialect/HAL/IR/CMakeLists.txt | 2 + .../compiler/Dialect/HAL/IR/HALDialect.cpp | 2 + .../compiler/Dialect/HAL/IR/HALInterfaces.td | 2 +- .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 13 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 366 ++++++++++++------ .../iree/compiler/Dialect/HAL/IR/HALOps.td | 187 +++++---- .../iree/compiler/Dialect/HAL/IR/HALTypes.h | 2 +- .../Dialect/HAL/IR/test/device_ops.mlir | 34 -- .../Dialect/HAL/IR/test/executable_ops.mlir | 42 ++ .../Dialect/HAL/IR/test/tensor_ops.mlir | 18 +- .../compiler/Dialect/HAL/Target/BUILD.bazel | 1 - .../Dialect/HAL/Target/CMakeLists.txt | 1 - .../Dialect/HAL/Target/ROCM/ROCMTarget.cpp | 2 +- .../Dialect/HAL/Target/TargetBackend.h | 1 - .../Target/VulkanSPIRV/VulkanSPIRVTarget.cpp | 2 +- .../HAL/Target/WebGPU/WebGPUTarget.cpp | 3 +- .../Dialect/HAL/Transforms/BUILD.bazel | 2 - .../Dialect/HAL/Transforms/CMakeLists.txt | 2 - .../Transforms/DumpExecutableBenchmarks.cpp | 2 +- .../HAL/Transforms/InlineDeviceSwitches.cpp | 175 --------- .../HAL/Transforms/MaterializeInterfaces.cpp | 53 ++- .../Transforms/MaterializeResourceCaches.cpp | 64 +-- .../Dialect/HAL/Transforms/Passes.cpp | 4 - .../compiler/Dialect/HAL/Transforms/Passes.h | 4 - .../HAL/Transforms/SubstituteExecutables.cpp | 4 +- .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 - .../HAL/Transforms/test/CMakeLists.txt | 1 - .../HAL/Transforms/test/convert_to_hal.mlir | 17 +- .../test/inline_device_switches.mlir | 84 ---- .../test/materialize_resource_caches.mlir | 45 ++- .../compiler/Dialect/HAL/Utils/BUILD.bazel | 29 -- .../compiler/Dialect/HAL/Utils/CMakeLists.txt | 29 -- .../Dialect/HAL/Utils/DeviceSwitchBuilder.h | 207 ---------- .../iree/compiler/Dialect/Util/IR/UtilOps.cpp | 31 ++ .../iree/compiler/Dialect/Util/IR/UtilOps.h | 14 + .../Conversion/HALToHALInline/BUILD.bazel | 1 - .../Conversion/HALToHALInline/CMakeLists.txt | 1 - .../Conversion/StreamToHALInline/BUILD.bazel | 1 - .../StreamToHALInline/CMakeLists.txt | 1 - .../Inline/Transforms/InlineExecutables.cpp | 2 +- .../Conversion/StreamToHALLoader/BUILD.bazel | 1 - .../StreamToHALLoader/CMakeLists.txt | 1 - .../Conversion/StreamToHALLoader/Patterns.cpp | 9 +- .../vulkan/shaders/example_inline.mlir | 20 +- 61 files changed, 684 insertions(+), 966 deletions(-) delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 8d853fc9946a..605468852c3c 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -431,7 +431,7 @@ class CUDATargetBackend final : public TargetBackend { // Collect all the entry point parameters. SmallVector> workgroupSizes; SmallVector workgroupLocalMemories; - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { std::array workgroupSize; if (std::optional workgroupSizeAttr = exportOp.getWorkgroupSize()) { @@ -472,7 +472,7 @@ class CUDATargetBackend final : public TargetBackend { // these to match the names in their kernels. We don't support any kind of // mangling and if the user was silly enough to rely on nvcc C++ mangling // they'll have to figure that out. - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { entryPointNames.emplace_back(exportOp.getSymName()); } @@ -503,8 +503,7 @@ class CUDATargetBackend final : public TargetBackend { } for (auto [exportOp, workgroupSize] : - llvm::zip_equal(variantOp.getOps(), - workgroupSizes)) { + llvm::zip_equal(variantOp.getExportOps(), workgroupSizes)) { auto *llvmFunc = llvmModule->getFunction(exportOp.getName()); if (llvmFunc->isDeclaration()) continue; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 900d220d7720..7b1ab5fe13d6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -96,7 +96,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Utils", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 8d5feafc38db..d7987204f4d8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -146,7 +146,6 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Util::IR iree::compiler::Dialect::Util::Transforms iree::compiler::Utils diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index 0748dc171937..589d9dc6f8e7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -85,7 +85,7 @@ void ConvertToDynamicSharedMemory(ModuleOp moduleOp) { // Add the amount of shared memory required as an attribute. auto variantOp = moduleOp->getParentOfType(); if (variantOp != nullptr) { - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { exportOp->setAttr(exportOp.getWorkgroupLocalMemoryAttrName(), builder.getIndexAttr(numberOfBytes)); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index 83de13f17dfb..3ba39a1a378b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -179,8 +179,7 @@ struct HALInterfaceLoadConstantConverter final // TODO(#1519): this conversion should look up the entry point information // to get the total push constant count. auto variantOp = loadOp->getParentOfType(); - auto exportOps = - llvm::to_vector<1>(variantOp.getOps()); + auto exportOps = llvm::to_vector<1>(variantOp.getExportOps()); assert(exportOps.size() == 1); auto layoutAttr = exportOps.front().getLayout(); diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index 824d8c274c0b..97d612432134 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -236,14 +236,14 @@ LogicalResult linkExecutablesInto( // Move any constant blocks that need to be preserved for future host // translation. There may be duplicates provided but they'll be cleaned // up in future passes. - for (auto constantBlockOp : llvm::make_early_inc_range( - variantOp.getOps())) { + for (auto constantBlockOp : + llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { constantBlockOp->moveBefore(&*linkedTargetBuilder.getInsertionPoint()); } // Clone export ops and queue remapping ordinals and updating // symbol refs. - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { auto newExportOp = linkedTargetBuilder.create( exportOp.getLoc(), exportOp.getSymNameAttr(), @@ -290,7 +290,7 @@ LogicalResult linkExecutablesInto( replaceEntryPointUses(moduleOp, symbolReplacements); // Remove if we didn't add anything. - if (linkedTargetOp.getOps().empty()) { + if (linkedTargetOp.getExportOps().empty()) { linkedTargetOp.erase(); linkedExecutableOp.erase(); } diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index c443639ddb44..bbfc47c850c9 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -39,7 +39,7 @@ FailureOr getEntryPoint(func::FuncOp funcOp) { if (!variantOp) return failure(); - for (auto op : variantOp.getOps()) { + for (auto op : variantOp.getExportOps()) { if (op.getSymName() == funcOp.getName()) { return op; } @@ -66,7 +66,7 @@ llvm::StringMap getAllEntryPoints(ModuleOp module) { auto variantOp = module->getParentOfType(); llvm::StringMap exportOps; - for (auto op : variantOp.getOps()) { + for (auto op : variantOp.getExportOps()) { exportOps[op.getSymName()] = op; } return exportOps; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir index 56f05a5ffc76..03462c513dd1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir @@ -202,6 +202,15 @@ func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32 // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>] // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1 %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1 + // Translates the workload (%x and %y captured above) into an XYZ workgroup + // count, optionally using device information. + count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) { + // Shows how device queries can be used when computing the workgroup count. + // The device is the one used at runtime. + %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32 + %z = arith.index_cast %z_i32 : i32 to index + hal.return %x_capture, %y_capture, %z : index, index, index + } // Must match the external definition. layout(#hal.pipeline.layout, %arg1: tensor<8xi32>, %arg2: i32 #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>], #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>] }>) - // Translates the workload (%x and %y captured above) into an XYZ workgroup - // count, optionally using device information. - count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) { - // Shows how device queries can be used when computing the workgroup count. - // The device is the one used at runtime. - %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32 - %z = arith.index_cast %z_i32 : i32 to index - hal.return %x_capture, %y_capture, %z : index, index, index - } // CHECK: return %[[RESULT]] return %result : tensor<8xi32> } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel index 292353f59e33..d87a54f88c18 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/IR", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt index c452ba50540b..04da911cf5b0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt @@ -26,7 +26,6 @@ iree_cc_library( MLIRMemRefDialect MLIRTransforms iree::compiler::Dialect::HAL::IR - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Util::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel index 9f7dba7e9737..c630dc4b55cd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel @@ -32,7 +32,6 @@ iree_compiler_cc_library( deps = [ "//compiler/src/iree/compiler/Dialect/HAL:hal_imports", "//compiler/src/iree/compiler/Dialect/HAL/IR", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt index 421ef74868bb..5a5cb8ac005b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt @@ -34,7 +34,6 @@ iree_cc_library( MLIRPass MLIRTransforms iree::compiler::Dialect::HAL::IR - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::HAL::hal_imports iree::compiler::Dialect::Util::IR iree::compiler::Dialect::VM::Conversion diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel index c9803ce4809d..52653fbd2b83 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt index 1fe1cbe23fa9..92dd0c674e4e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt @@ -29,7 +29,6 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR PUBLIC diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index d0a809e7dc27..46e21e6e4bce 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -9,7 +9,6 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" -#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" @@ -920,16 +919,6 @@ struct CmdCollectiveOpPattern } }; -// Returns a hal.device.switch match expression that selects the given export. -static Attribute -getExportConditionAttr(IREE::HAL::ExecutableExportOp exportOp) { - // TODO(benvanik): customizable selection logic. Today this just checks - // whether the variant target is supported but we can also allow - // specialization of entry points based on dispatch site parameters. - auto variantOp = exportOp->getParentOfType(); - return variantOp.getTarget().getMatchExpression(); -} - struct CmdDispatchOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -942,42 +931,65 @@ struct CmdDispatchOpPattern // Get the device handle we're executing against in this execution region. // Note that this is a dynamic value: we have to treat the device as unknown // here. - auto device = rewriter.create( + auto deviceValue = rewriter.create( loc, rewriter.getType(), commandBuffer); - // Ask each target backend to record their dispatch logic. - IREE::HAL::DeviceSwitchRewriter switchRewriter(loc, - /*resultTypes=*/TypeRange{}, - device, rewriter); + // Prepare for variant switch table by gathering the conditions selecting + // each variant. + SmallVector caseIndices; + SmallVector> + caseExportOps; dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) { // NOTE: slow lookup! auto exportOp = SymbolTable::lookupNearestSymbolFrom( dispatchOp, entryPointAttr); assert(exportOp && "dispatch target export not found"); + caseIndices.push_back(caseIndices.size()); + caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp)); + }); - // Setup the case condition for the entry point. - auto *caseRegion = - switchRewriter.addConditionRegion(getExportConditionAttr(exportOp)); - auto &entryBlock = caseRegion->front(); - auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock); + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseExportOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + auto exportOp = caseExportOps[i].second; + auto variantOp = + exportOp->getParentOfType(); + return variantOp.buildCondition(deviceValue, rewriter); + }, + rewriter); + + // Allow each variant to define how it is dispatched. + auto switchOp = rewriter.replaceOpWithNewOp( + dispatchOp, TypeRange{}, selectedIndex, caseIndices, + caseIndices.size()); + for (size_t i = 0; i < caseExportOps.size(); ++i) { + auto entryPointAttr = caseExportOps[i].first; + auto exportOp = caseExportOps[i].second; + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); // Record push constants and buffer bindings. - recordParameters(loc, device, commandBuffer, dispatchOp, adaptor, + recordParameters(loc, deviceValue, commandBuffer, dispatchOp, adaptor, exportOp.getLayout(), caseBuilder); // Dispatch with a target-specific workgroup count. auto caseWorkgroupCount = exportOp.calculateWorkgroupCount( - loc, device, adaptor.getWorkload(), caseBuilder); + loc, deviceValue, adaptor.getWorkload(), caseBuilder); caseBuilder.create( loc, commandBuffer, entryPointAttr, caseWorkgroupCount[0], caseWorkgroupCount[1], caseWorkgroupCount[2]); - caseBuilder.create(loc); - }); - switchRewriter.build(); + caseBuilder.create(loc); + } + + // Fallback for no available variant. Today we just no-op as executable + // loading should have already failed. + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + defaultBuilder.create(loc); - rewriter.eraseOp(dispatchOp); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 23eac6d27ac0..d9e68736e84f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -164,9 +164,10 @@ func.func @cmdExecute(%arg0: !stream.resource, %arg1: index, %arg2: ! // ----- -#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> +#executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> +#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> #device_target_cpu = #hal.device.target<"llvm-cpu", { - executable_targets = [#executable_target_embedded_elf_x86_64] + executable_targets = [#executable_target_aarch64, #executable_target_x86_64] }> #pipeline_layout = #hal.pipeline.layout, %arg1: index, %arg2: ! ]> ]> hal.executable private @ex { - hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { + hal.executable.variant public @aarch64 target(#executable_target_aarch64) { + hal.executable.condition(%device: !hal.device) -> i1 { + %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 + hal.return %selected : i1 + } + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + // Opaque at this point (in some target-specific dialects). + } + } + hal.executable.variant public @x86_64 target(#executable_target_x86_64) { hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { translation_info = #iree_codegen.translation_info } { @@ -203,9 +221,19 @@ func.func @cmdDispatch(%arg0: !stream.resource, %arg1: index, %arg2: %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { - // Switch for each executable variant: - // CHECK: hal.device.switch - // CHECK-NEXT: #hal.device.match.executable.format<"embedded-elf-x86_64"> + // Switch for each executable variant by checking conditions and ranking: + // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer> + // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") + // CHECK-DAG: %[[AARCH64_FEATURE:.+]] = scf.execute_region -> i1 { + // CHECK-NEXT: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature") + // CHECK-NEXT: scf.yield %[[FEATURE]] + // CHECK-NEXT: } + // CHECK-DAG: %[[AARCH64_SELECTED:.+]] = arith.andi %[[AARCH64_FORMAT]], %[[AARCH64_FEATURE]] + // CHECK-DAG: %{{.+}}, %[[X86_64_SELECTED:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64") + // CHECK: %[[VARIANT1:.+]] = arith.select %[[X86_64_SELECTED]], %c1 + // CHECK: %[[VARIANT0:.+]] = arith.select %[[AARCH64_SELECTED]], %c0{{.+}}, %[[VARIANT1]] + // CHECK: scf.index_switch %[[VARIANT0]] + // CHECK-NEXT: case 0 { // Cache queries: // CHECK-DAG: %[[LAYOUT:.+]] = hal.pipeline_layout.lookup {{.+}} layout(#pipeline_layout) @@ -230,9 +258,14 @@ func.func @cmdDispatch(%arg0: !stream.resource, %arg1: index, %arg2: // Dispatch: // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] - // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch) + // CHECK-SAME: target(@ex::@aarch64::@dispatch) // CHECK-SAME: workgroups([%[[X]], %[[YZ]], %[[YZ]]]) - stream.cmd.dispatch @ex::@embedded_elf_x86_64::@dispatch[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) { + + // Other variant, when selected: + // CHECK: case 1 { + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] + // CHECK-SAME: target(@ex::@x86_64::@dispatch) + stream.cmd.dispatch {@ex::@aarch64::@dispatch, @ex::@x86_64::@dispatch}[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) { ro %arg4[%c0 for %c128] : !stream.resource{%arg1}, wo %arg5[%c0 for %c128] : !stream.resource{%arg3} } attributes { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index e5a47aca74e5..4c543db7eef5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -81,6 +81,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -105,6 +106,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TransformUtils", ], ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 8d6c8926326c..64d0fb1974e9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -47,6 +47,7 @@ iree_cc_library( MLIRIR MLIRMemRefDialect MLIRParser + MLIRSCFDialect MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils @@ -73,6 +74,7 @@ iree_cc_library( MLIRIR MLIRMemRefDialect MLIRParser + MLIRSCFDialect MLIRTransformUtils iree::compiler::Dialect::HAL::Conversion::HALToVM iree::compiler::Dialect::HAL::hal_imports diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index 7d7655267d23..decdf937717f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -127,6 +128,7 @@ class HALToVMConversionInterface : public VMConversionDialectInterface { HALDialect::HALDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) { context->loadDialect(); + context->loadDialect(); context->loadDialect(); registerAttributes(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td index cdf89aecd09d..b1df6fffa402 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td @@ -25,7 +25,7 @@ def HAL_MatchAttrInterface : attribute is true for the given value. }], "Value", "buildConditionExpression", - (ins "Location":$loc, "Value":$value, "OpBuilder":$builder) + (ins "Location":$loc, "Value":$device, "OpBuilder":$builder) >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 6d52a2009986..dfc893d297e5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -350,16 +350,6 @@ void CommandBufferPushDescriptorSetOp::getCanonicalizationPatterns( results.insert(context); } -//===----------------------------------------------------------------------===// -// hal.device.switch -//===----------------------------------------------------------------------===// - -// TODO(benvanik): fold conditions with the same IR tree. -// TODO(benvanik): remove duplicate conditions. -// TODO(benvanik): fold condition expressions (any(always, ...) -> always, etc). -// TODO(benvanik): completely replace switches with just one always block. -// TODO(benvanik): remove conditions with no side-effects. - //===----------------------------------------------------------------------===// // hal.device.match.id //===----------------------------------------------------------------------===// @@ -682,8 +672,7 @@ struct MergeExecutableConstantBlocks using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExecutableVariantOp variantOp, PatternRewriter &rewriter) const override { - auto blockOps = - llvm::to_vector(variantOp.getOps()); + auto blockOps = llvm::to_vector(variantOp.getConstantBlockOps()); if (blockOps.size() <= 1) { return rewriter.notifyMatchFailure(variantOp, "not enough blocks to merge"); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 4157fe16861e..018cc77051ad 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -10,6 +10,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -113,6 +114,146 @@ static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op, p.printNewline(); } +//===----------------------------------------------------------------------===// +// custom($body) +//===----------------------------------------------------------------------===// + +static FunctionType getTargetConditionRegionType(MLIRContext *context) { + return FunctionType::get(context, + { + IREE::HAL::DeviceType::get(context), + }, + { + IntegerType::get(context, 1), + }); +} + +static LogicalResult verifyTargetConditionRegion(Operation *op, + Region ®ion) { + // Ignore if empty. + if (region.empty()) + return success(); + + // Verify region takes a !hal.device. + if (region.getNumArguments() != 1 || + !isa(region.getArgumentTypes().front())) { + return op->emitOpError() + << "target condition region must take a !hal.device"; + } + + // Verify i1 return. + for (auto returnOp : region.getOps()) { + if (returnOp.getNumOperands() != 1) { + return returnOp.emitOpError() + << "target condition region must return a single i1 result"; + } + for (auto returnType : returnOp.getOperandTypes()) { + if (!returnType.isInteger(1)) { + return returnOp.emitOpError() + << "target condition region must return a single i1 result"; + } + } + } + + return success(); +} + +static ParseResult parseTargetConditionRegion(OpAsmParser &parser, + Region &body) { + SmallVector args; + if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren, + /*allowType=*/true, + /*allowAttrs=*/true))) { + return failure(); + } + + SmallVector returnTypes; + if (failed(parser.parseArrowTypeList(returnTypes))) { + return failure(); + } + if (returnTypes.size() != 1 || + !llvm::all_of(returnTypes, [](Type type) { return type.isInteger(1); })) { + return parser.emitError(parser.getCurrentLocation()) + << "target condition region must return one i1"; + } + + return parser.parseRegion(body, args, /*enableNameShadowing=*/false); +} + +static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op, + Region &body) { + if (body.empty()) + return; + p << "("; + llvm::interleaveComma(body.getArguments(), p, + [&](BlockArgument arg) { p.printRegionArgument(arg); }); + p << ")"; + p.printArrowTypeList(TypeRange{IntegerType::get(body.getContext(), 1)}); + p << " "; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +//===----------------------------------------------------------------------===// +// custom($targets, $objects, $target_regions) +//===----------------------------------------------------------------------===// + +static ParseResult parseConditionalTargetRegions( + OpAsmParser &parser, ArrayAttr &targetsAttr, ArrayAttr &objectsAttr, + SmallVectorImpl> &targetRegions) { + auto builder = parser.getBuilder(); + SmallVector targetAttrs; + SmallVector objectsAttrs; + do { + IREE::HAL::ExecutableTargetAttr targetAttr; + if (failed(parser.parseAttribute(targetAttr))) + return failure(); + targetAttrs.push_back(targetAttr); + std::unique_ptr targetRegion = std::make_unique(); + if (succeeded(parser.parseOptionalKeyword("if"))) { + if (failed(parseTargetConditionRegion(parser, *targetRegion))) + return failure(); + } + targetRegions.emplace_back(std::move(targetRegion)); + if (failed(parser.parseEqual())) + return failure(); + ArrayAttr targetObjectsAttr; + if (failed(parser.parseAttribute(targetObjectsAttr))) + return failure(); + objectsAttrs.push_back(targetObjectsAttr); + } while (succeeded(parser.parseOptionalComma())); + targetsAttr = builder.getArrayAttr(targetAttrs); + objectsAttr = builder.getArrayAttr(objectsAttrs); + return success(); +} + +static void +printConditionalTargetRegions(OpAsmPrinter &p, Operation *op, + ArrayAttr targetsAttr, ArrayAttr objectsAttr, + MutableArrayRef targetRegions) { + p.increaseIndent(); + p.printNewline(); + llvm::interleave( + llvm::zip_equal(targetsAttr.getAsRange(), + objectsAttr.getAsRange(), targetRegions), + [&](auto it) { + auto [targetAttr, targetObjectsAttr, targetRegion] = it; + p.printAttribute(targetAttr); + if (!targetRegion.empty()) { + p << " if"; + printTargetConditionRegion(p, op, targetRegion); + } + p << " = "; + p.printAttribute(targetObjectsAttr); + }, + [&]() { + p << ","; + p.printNewline(); + }); + p.decreaseIndent(); + p.printNewline(); +} + //===----------------------------------------------------------------------===// // custom($body) //===----------------------------------------------------------------------===// @@ -161,12 +302,8 @@ static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, if (body.empty()) return; p << "("; - auto args = body.getArguments(); - for (unsigned i = 0; i < args.size(); ++i) { - if (i > 0) - p << ", "; - p.printRegionArgument(args[i]); - } + llvm::interleaveComma(body.getArguments(), p, + [&](BlockArgument arg) { p.printRegionArgument(arg); }); p << ")"; Type indexType = IndexType::get(body.getContext()); p.printArrowTypeList(TypeRange{indexType, indexType, indexType}); @@ -704,113 +841,6 @@ LogicalResult DeviceQueryOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// hal.device.switch -//===----------------------------------------------------------------------===// - -void DeviceSwitchOp::build(OpBuilder &builder, OperationState &state, - TypeRange resultTypes, Value device, - ArrayRef conditions, - ArrayRef attributes) { - state.addOperands({device}); - state.addAttribute("conditions", builder.getArrayAttr(conditions)); - for (size_t i = 0; i < conditions.size(); ++i) { - state.addRegion(); - } - state.addTypes(resultTypes); - state.addAttributes(attributes); -} - -ParseResult DeviceSwitchOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand device; - Type deviceType; - if (failed(parser.parseLess()) || failed(parser.parseOperand(device)) || - failed(parser.parseColonType(deviceType)) || - failed(parser.resolveOperand(device, deviceType, result.operands)) || - failed(parser.parseGreater()) || - failed(parser.parseOptionalArrowTypeList(result.types))) { - return failure(); - } - - // Parses each switch condition attribute and region, like: - // #hal.device.match.id<"vulkan-v1.?-*"> { - // hal.return %c1 : i32 - // }, ... - SmallVector conditionAttrs; - do { - Attribute conditionAttr; - NamedAttrList dummyAttrs; - if (failed(parser.parseAttribute(conditionAttr, "condition", dummyAttrs))) { - return failure(); - } - conditionAttrs.push_back(conditionAttr); - SmallVector regionArgs; - auto *regionBody = result.addRegion(); - if (failed(parser.parseRegion(*regionBody, regionArgs))) { - return failure(); - } - } while (succeeded(parser.parseOptionalComma())); - result.addAttribute("conditions", - ArrayAttr::get(result.getContext(), conditionAttrs)); - - if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) { - return failure(); - } - return success(); -} - -void DeviceSwitchOp::print(OpAsmPrinter &p) { - Operation *op = getOperation(); - p << "<"; - p.printOperand(getDevice()); - p << " : "; - p.printType(getDevice().getType()); - p << ">"; - p.printOptionalArrowTypeList(getResultTypes()); - p << "\n"; - p.getStream().indent(4); - interleave( - llvm::zip_equal(getConditions(), getConditionRegions()), - [&](std::tuple it) { - auto &conditionAttr = std::get<0>(it); - auto &conditionRegion = std::get<1>(it); - p.printAttribute(conditionAttr); - p << " "; - p.printRegion(conditionRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - }, - [&]() { - p << ",\n"; - p.getStream().indent(4); - }); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), - /*elidedAttrs=*/{"conditions"}); -} - -LogicalResult DeviceSwitchOp::verify() { - DeviceSwitchOp op = *this; - if (op.getConditions().size() != op.getConditionRegions().size()) { - return op.emitOpError() << "requires conditions and regions be matched 1:1"; - } else if (op.getConditionRegions().empty()) { - return op.emitOpError() << "requires at least one condition"; - } - for (auto ®ion : op.getConditionRegions()) { - for (auto &block : region) { - if (auto returnOp = - dyn_cast_or_null(block.getTerminator())) { - if (!std::equal(returnOp.getOperandTypes().begin(), - returnOp.getOperandTypes().end(), - op.getResultTypes().begin())) { - return op.emitOpError() - << "requires all regions return the same types"; - } - } - } - } - return success(); -} - //===----------------------------------------------------------------------===// // hal.device.queue.* //===----------------------------------------------------------------------===// @@ -857,6 +887,21 @@ LogicalResult DeviceQueueExecuteOp::verify() { return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); } +//===----------------------------------------------------------------------===// +// hal.executable.source +//===----------------------------------------------------------------------===// + +LogicalResult ExecutableSourceOp::verify() { + ExecutableSourceOp op = *this; + + auto conditionOps = getOps(); + if (llvm::range_size(conditionOps) > 1) + return op.emitOpError() + << "only one condition op is allowed in an executable"; + + return success(); +} + //===----------------------------------------------------------------------===// // hal.executable //===----------------------------------------------------------------------===// @@ -1102,9 +1147,19 @@ void ExecutableVariantOp::build(OpBuilder &builder, OperationState &state, state.addAttribute("target", target); } +LogicalResult ExecutableVariantOp::verify() { + ExecutableVariantOp op = *this; + + auto conditionOps = getOps(); + if (llvm::range_size(conditionOps) > 1) + return op.emitOpError() << "only one condition op is allowed in a variant"; + + return success(); +} + DenseMap ExecutableVariantOp::gatherConstantOrdinals() { DenseMap map; - for (auto blockOp : getOps()) { + for (auto blockOp : getConstantBlockOps()) { int baseCount = map.size(); for (auto [i, keyAttr] : llvm::enumerate(blockOp.getKeys())) { map.try_emplace(keyAttr, baseCount + i); @@ -1113,6 +1168,89 @@ DenseMap ExecutableVariantOp::gatherConstantOrdinals() { return map; } +Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) { + // Base case dependent on target information. + auto matchAttr = + cast(getTarget().getMatchExpression()); + auto selected = matchAttr.buildConditionExpression(getLoc(), device, builder); + + // Factor in variant condition region, if any. + auto conditionOp = getConditionOp(); + if (conditionOp) { + auto regionOp = builder.create(conditionOp.getLoc(), + builder.getI1Type()); + + IRMapping mapper; + mapper.map(conditionOp.getRegion().getArgument(0), device); + conditionOp.getRegion().cloneInto(®ionOp.getRegion(), mapper); + + for (auto returnOp : + llvm::make_early_inc_range(regionOp.getOps())) { + OpBuilder(returnOp).create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + } + + selected = builder.create(getLoc(), selected, + regionOp.getResult(0)); + } + + return selected; +} + +//===----------------------------------------------------------------------===// +// hal.executable.condition +//===----------------------------------------------------------------------===// + +LogicalResult ExecutableConditionOp::verify() { + ExecutableConditionOp op = *this; + return verifyTargetConditionRegion(op, op.getBody()); +} + +void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result, + ArrayRef attrs) { + result.addAttribute( + "function_type", + TypeAttr::get(getTargetConditionRegionType(builder.getContext()))); + result.addRegion(); + result.attributes.append(attrs.begin(), attrs.end()); +} + +ParseResult ExecutableConditionOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parseTargetConditionRegion(parser, *result.addRegion())) + return failure(); + result.addAttribute( + "function_type", + TypeAttr::get(getTargetConditionRegionType(parser.getContext()))); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + return success(); +} + +void ExecutableConditionOp::print(OpAsmPrinter &p) { + Operation *op = getOperation(); + printTargetConditionRegion(p, op, getBody()); + p.printOptionalAttrDictWithKeyword(op->getAttrs(), + /*elidedAttrs=*/{"function_type"}); +} + +Block *ExecutableConditionOp::addEntryBlock() { + assert(empty() && "function already has an entry block"); + auto *entry = new Block(); + auto argTypes = getArgumentTypes(); + SmallVector argLocs(argTypes.size(), getLoc()); + entry->addArguments(argTypes, argLocs); + push_back(entry); + return entry; +} + +Block *ExecutableConditionOp::addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new Block()); + return &back(); +} + //===----------------------------------------------------------------------===// // hal.executable.constant.block //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index fc27fe41a2ea..29f94572dbae 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -394,10 +394,10 @@ def HAL_DispatchExternOp : HAL_PureOp<"dispatch.extern", [ type($arguments), $argument_dims, type($results), $result_dims, $tied_operands) + `count` `` custom($workgroup_count) `layout` `(` $layout `)` (`bindings` `(` $bindings^ `)`)? `objects` `(` $objects `)` - `count` `` custom($workgroup_count) attr-dict-with-keyword }]; @@ -1534,88 +1534,8 @@ def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [ ]; } -def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [ - NoRegionArguments, - RecursiveMemoryEffects, - ]> { - let summary = [{runtime device switch pseudo op}]; - let description = [{ - Switches between multiple regions based on the runtime device type. - The provided regions are matched against the runtime backend of the given - device and executed only when the device matches the conditions. - - Conditions can match on wildcards and be folded to enable conditions that - have similar bodies to be folded. The patterns themselves are only matched - once at startup and then the results are cached; the runtime overhead is - equivalent to a normal switch statement. In cases where the compiler can - statically identify the device type entire cases can be folded away. - - Supported conditions: - * `#hal.match...`: execute the region if the expression matches. - - Supported match expressions: - * `#hal.match.always`: always matches; useful for defaults. - * `#hal.match.any<[...]>`: matches if any of the nested expressions match. - * `#hal.match.all<[...]>`: matches only if all of the nested expressions - match. - * `#hal.device.match.id<"pattern*-?-*">`: matches against the device - identifier. The pattern is evaluated with standard file path wildcards - (`*` for zero or more characters and `?` for one character). - - If more than one condition is satisfied the first listed will be chosen. - More specific conditions should be earlier in the set. If no condition is - matched but there are return values the switch will abort at runtime. It's - strongly recommend that all switches that return values end with a trailing - `#hal.match.always` condition to handle the fallthrough case. - - Upon creation each condition region will have an empty entry block with the - specified operands available as arguments. Each region must be setup to - return the same types. - - ```mlir - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %c2 = arith.constant 2 : i32 - %device = ... : !hal.device - %0 = hal.device.switch<%device : !hal.device> -> i32 - #hal.device.match.id<"vulkan-v1.?-*"> { - hal.return %c1 : i32 - }, - #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> { - hal.return %c2 : i32 - }, - #hal.match.always { - hal.return %c0 : i32 - } - ``` - }]; - - let arguments = (ins - HAL_Device:$device, - ArrayAttr:$conditions - ); - let results = (outs - Variadic:$results - ); - - let regions = (region VariadicRegion:$condition_regions); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins - "TypeRange":$resultTypes, - "Value":$device, - "ArrayRef":$conditions, - CArg<"ArrayRef", "{}">:$attributes - )>, - ]; - - - let hasVerifier = 1; -} - def HAL_ReturnOp : HAL_Op<"return", [Terminator]> { - let summary = [{return from a hal.device.switch region}]; + let summary = [{return from a hal.* region}]; let description = [{ Returns the given values from the region and back to the host code. }]; @@ -1955,19 +1875,34 @@ def HAL_ExecutableSourceOp : HAL_Op<"executable.source", [ OptionalAttr:$objects ); - let regions = (region SizedRegion<1>:$body); + let regions = (region + SizedRegion<1>:$body + ); let assemblyFormat = [{ custom($sym_visibility) $sym_name attr-dict-with-keyword `` - regions + $body }]; let extraClassDeclaration = [{ Block& getBlock() { return getBody().front(); } + IREE::HAL::ExecutableConditionOp getConditionOp() { + auto conditionOps = getBody().getOps(); + return !conditionOps.empty() ? *conditionOps.begin() : IREE::HAL::ExecutableConditionOp{}; + } + iterator_range> + getConstantBlockOps() { + return getBody().getOps(); + } + iterator_range> + getExportOps() { + return getBody().getOps(); + } + bool isExternal() { return getBlock().getOps<::mlir::ModuleOp>().empty(); } @@ -1978,6 +1913,8 @@ def HAL_ExecutableSourceOp : HAL_Op<"executable.source", [ return *it.begin(); } }]; + + let hasVerifier = 1; } def HAL_ExecutableSourceEndOp : HAL_Op<"executable.source_end", [ @@ -2114,6 +2051,12 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ let description = [{ The target IR for the executable. This can be preserved for debugging but is usually removed during transformation. + + Variants are selected based on their target and an optional condition + op that returns true if the variant is valid for use on the provided + runtime `!hal.device`. If no variants within an executable are valid then + loading will fail at runtime. If multiple variants are valid the first valid + one found will be loaded and used for execution. }]; let arguments = (ins @@ -2123,7 +2066,9 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ OptionalAttr:$objects ); - let regions = (region SizedRegion<1>:$body); + let regions = (region + SizedRegion<1>:$body + ); let assemblyFormat = [{ custom($sym_visibility) @@ -2131,7 +2076,7 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ `target` `(` $target `)` (`objects` `(` $objects^ `)` )? attr-dict-with-keyword - regions + $body }]; let skipDefaultBuilders = 1; @@ -2142,6 +2087,19 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ let extraClassDeclaration = [{ Block& getBlock() { return getBody().front(); } + IREE::HAL::ExecutableConditionOp getConditionOp() { + auto conditionOps = getBody().getOps(); + return !conditionOps.empty() ? *conditionOps.begin() : IREE::HAL::ExecutableConditionOp{}; + } + iterator_range> + getConstantBlockOps() { + return getBody().getOps(); + } + iterator_range> + getExportOps() { + return getBody().getOps(); + } + bool isExternal() { return getBlock().getOps<::mlir::ModuleOp>().empty(); } @@ -2155,9 +2113,13 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ // Returns a map of constant key attributes to ordinals across all constant // blocks inside the variant. DenseMap gatherConstantOrdinals(); + + // Returns an i1 indicating whether this variant should be selected. + Value buildCondition(Value device, OpBuilder &builder); }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def HAL_ExecutableVariantEndOp : HAL_Op<"executable.variant_end", [ @@ -2168,6 +2130,59 @@ def HAL_ExecutableVariantEndOp : HAL_Op<"executable.variant_end", [ let assemblyFormat = "attr-dict"; } +def HAL_ExecutableConditionOp : HAL_Op<"executable.condition", [ + IsolatedFromAbove, + FunctionOpInterface, + CallableOpInterface, + ]> { + let summary = [{host code to determine if the executable is enabled}]; + let description = [{ + Variants are selected based on their target and this optional condition + op that returns true if the variant is valid for use on the provided + runtime `!hal.device`. If no variants within an executable are valid then + loading will fail at runtime. If multiple variants are valid the first valid + one found will be loaded and used for execution. + }]; + + let arguments = (ins + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + CArg<"ArrayRef", "{}">:$attrs + )>, + ]; + + let extraClassDeclaration = [{ + /// Add an entry block to an empty function and set up the block arguments + /// to match the signature of the function. + Block *addEntryBlock(); + Block *addBlock(); + + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + LogicalResult verifyType() { return success(); } + + Region *getCallableRegion() { return &getBody(); } + ArrayRef getCallableResults() { return getResultTypes(); } + + ::mlir::ArrayAttr getCallableArgAttrs() { return nullptr; } + ::mlir::ArrayAttr getCallableResAttrs() { return nullptr; } + + /// Make symbol optional as this op has no symbol. + bool isOptionalSymbol() { return true; } + }]; + + let hasVerifier = 1; +} + def HAL_ExecutableConstantBlockOp : HAL_Op<"executable.constant.block", [ ParentOneOf<[ diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index 665625e88044..2d5a196650d9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h @@ -41,7 +41,7 @@ namespace HAL { #include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.h.inc" // IWYU pragma: export //===----------------------------------------------------------------------===// -// Enum utilities +// Utilities //===----------------------------------------------------------------------===// // Returns a stable identifier for the MLIR element type or nullopt if the diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index c163e1440d15..a04e5db6bd0e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir @@ -10,40 +10,6 @@ func.func @device_allocator(%device: !hal.device) -> !hal.allocator { // ----- -// CHECK-LABEL: @device_switch -// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) -func.func @device_switch(%device: !hal.device) -> i32 { - // CHECK-DAG: %[[C0:.+]] = arith.constant 0 - %c0 = arith.constant 0 : i32 - // CHECK-DAG: %[[C1:.+]] = arith.constant 1 - %c1 = arith.constant 1 : i32 - // CHECK-DAG: %[[C2:.+]] = arith.constant 2 - %c2 = arith.constant 2 : i32 - // CHECK: = hal.device.switch<%[[DEVICE]] : !hal.device> -> i32 - %0 = hal.device.switch<%device : !hal.device> -> i32 - // CHECK-NEXT: #hal.device.match.id<"vulkan-v1.?-*"> { - #hal.device.match.id<"vulkan-v1.?-*"> { - // CHECK-NEXT: hal.return %[[C1]] : i32 - hal.return %c1 : i32 - // CHECK-NEXT: }, - }, - // CHECK-NEXT: #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> { - #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> { - // CHECK-NEXT: hal.return %[[C2]] : i32 - hal.return %c2 : i32 - // CHECK-NEXT: }, - }, - // CHECK-NEXT: #hal.match.always { - #hal.match.always { - // CHECK-NEXT: hal.return %[[C0]] : i32 - hal.return %c0 : i32 - // CHECK-NEXT: } - } - return %0 : i32 -} - -// ----- - // CHECK-LABEL: @device_query // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) func.func @device_query(%device : !hal.device) -> (i1, i32) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir index a2590dbe7ee0..acd872d55524 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir @@ -1,6 +1,7 @@ // RUN: iree-opt --split-input-file %s | FileCheck %s #executable_target_format = #hal.executable.target<"backend", "format"> + // CHECK-LABEL: @ex hal.executable @ex { // CHECK: hal.executable.variant public @backend @@ -67,6 +68,47 @@ hal.executable @ex_with_workgroup_count_region { #executable_target_format = #hal.executable.target<"backend", "format"> +// CHECK-LABEL: @ex_with_condition +hal.executable @ex_with_condition { + // CHECK: hal.executable.variant public @backend target(#executable_target_format + hal.executable.variant @backend target(#executable_target_format) { + // CHECK: hal.executable.condition(%[[DEVICE:.+]]: !hal.device) -> i1 { + hal.executable.condition(%device: !hal.device) -> i1 { + // CHECK-NEXT: %[[OK:.+]], %[[VALUE:.+]] = hal.device.query<%[[DEVICE]] + %ok, %value = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32 + // CHECK-NEXT: return %[[OK]] + hal.return %ok : i1 + } + + // CHECK-DAG: hal.executable.export public @entry0 ordinal(0) layout(#pipeline_layout) attributes { + // CHECK-SAME: subgroup_size = 64 : index + // CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index] + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer> + ]> + ]>) attributes { + subgroup_size = 64 : index, + workgroup_size = [4 : index, 1 : index, 1 : index] + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): + hal.return %arg0, %arg1, %arg2 : index, index, index + } + } + // CHECK: hal.executable.binary + hal.executable.binary @backend_binary attributes { + // CHECK-SAME: data = dense<1> : vector<128xi8>, + data = dense<1> : vector<128xi8>, + // CHECK-SAME: format = "some_format" + format = "some_format" + } +} + +// ----- + +#executable_target_format = #hal.executable.target<"backend", "format"> + // CHECK-LABEL: @ex_with_constants hal.executable @ex_with_constants { // CHECK: hal.executable.variant public @backend diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir index 86b81c971679..42f1d9e09b32 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir @@ -64,6 +64,15 @@ func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32 // Dispatch workgroups to the externally defined function "main" in the // referenced object files. %0 = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1 + // Translates the workload (%x and %y captured above) into an XYZ workgroup + // count, optionally using device information. + count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) { + // Shows how device queries can be used when computing the workgroup count. + // The device is the one used at runtime. + %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32 + %z = arith.index_cast %z_i32 : i32 to index + hal.return %x_capture, %y_capture, %z : index, index, index + } // Must match the external definition. layout(#hal.pipeline.layout, %arg1: tensor<8xi32>, %arg2: i32 #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>], #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>] }>) - // Translates the workload (%x and %y captured above) into an XYZ workgroup - // count, optionally using device information. - count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) { - // Shows how device queries can be used when computing the workgroup count. - // The device is the one used at runtime. - %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32 - %z = arith.index_cast %z_i32 : i32 to index - hal.return %x_capture, %y_capture, %z : index, index, index - } return %0 : tensor<8xi32> } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel index 0ba1199db5ac..655eb9ab3bde 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( deps = [ "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt index 689783e33819..12aca6d30aa7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt @@ -27,7 +27,6 @@ iree_cc_library( MLIRTransforms iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp index 88ee535c3525..b72f31d627b9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp @@ -147,7 +147,7 @@ class ROCMTargetBackend final : public TargetBackend { // Collect all the entry point names. llvm::StringMap exportOps; - for (auto op : variantOp.getOps()) { + for (auto op : variantOp.getExportOps()) { exportOps[op.getSymName()] = op; } std::vector> workgroupSizes; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h index 4c61dcb47690..33fc13677862 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h @@ -14,7 +14,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h" #include "iree/compiler/Utils/OptionUtils.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index 829dbefa7b87..51df7ca710b8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -250,7 +250,7 @@ class VulkanSPIRVTargetBackend : public TargetBackend { // Take exported names verbatim for passing into VkShaderModuleCreateInfo. SmallVector entryPointNames; - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { entryPointNames.emplace_back(exportOp.getSymName()); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp index 0905020df794..8018ca237679 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp @@ -148,8 +148,7 @@ class WebGPUTargetBackend : public TargetBackend { // For each executable entry point op, rename the entry point symbol using // that convention and keep track of the mapping between entry point // ordinals to which shader module they reference. - auto exportOps = - llvm::to_vector(variantOp.getOps()); + auto exportOps = llvm::to_vector(variantOp.getExportOps()); llvm::SmallVector entryPointOrdinals(exportOps.size()); SymbolTableCollection symbolTable; SymbolUserMap symbolUsers(symbolTable, variantOp); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index f511949fa767..6b69403fe820 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -22,7 +22,6 @@ iree_compiler_cc_library( "DumpExecutableSources.cpp", "ElideRedundantCommands.cpp", "FixupLegacySync.cpp", - "InlineDeviceSwitches.cpp", "LinkExecutables.cpp", "MaterializeDispatchInstrumentation.cpp", "MaterializeInterfaces.cpp", @@ -52,7 +51,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Stream/Transforms", "//compiler/src/iree/compiler/Dialect/Util/Conversion", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 703768bbd7b3..81bee1b74ef7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -23,7 +23,6 @@ iree_cc_library( "DumpExecutableSources.cpp" "ElideRedundantCommands.cpp" "FixupLegacySync.cpp" - "InlineDeviceSwitches.cpp" "LinkExecutables.cpp" "MaterializeDispatchInstrumentation.cpp" "MaterializeInterfaces.cpp" @@ -63,7 +62,6 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Stream::Transforms iree::compiler::Dialect::Util::Conversion diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 4bad200fa4a1..1496e284d5be 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -408,7 +408,7 @@ buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, // Add functions to test each entry point with its various dispatch // parameters. bool hasAnyBenchmarks = false; - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { auto symbolRefAttr = SymbolRefAttr::get(executableOp.getNameAttr(), { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp deleted file mode 100644 index a40e7b8f83ef..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" -#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" -#include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "llvm/ADT/StringSet.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace HAL { - -// Inlines a condition region from a switch op into the function at the given -// point. This assumes that the insertion point will only be reached if the -// condition the region is predicated on is true. -static void inlineConditionRegion(Region &conditionRegion, Block *exitBlock, - OpBuilder funcBuilder) { - assert(!conditionRegion.empty() && "source regions must not be empty"); - assert(conditionRegion.front().getNumArguments() == 0 && - "switch does not capture"); - - // Splice in the region blocks. - auto *insertBlock = funcBuilder.getBlock(); - auto postInsertBlockIt = std::next(insertBlock->getIterator())->getIterator(); - auto *insertRegion = insertBlock->getParent(); - insertRegion->getBlocks().splice(postInsertBlockIt, - conditionRegion.getBlocks()); - auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), - postInsertBlockIt); - auto *firstNewBlock = &*newBlocks.begin(); - - // Handle the hal.return ops which will transfer control to the exitBlock. - for (auto &newBlock : newBlocks) { - if (auto returnOp = - dyn_cast(newBlock.getTerminator())) { - OpBuilder branchBuilder(returnOp); - branchBuilder.create(returnOp.getLoc(), exitBlock, - returnOp.getOperands()); - returnOp.erase(); - } - } - - // Splice the instructions of the inlined entry block into the insert block. - insertBlock->getOperations().splice(insertBlock->end(), - firstNewBlock->getOperations()); - firstNewBlock->erase(); -} - -// Inlines each switch condition region into the parent function predicated on -// the switch condition expression. -// -// Since switch conditions are evaluated in the order they are defined we can -// trivially turn the switch into a chain of if-else blocks. -// if condition_0_match: -// -// else -// if condition_1_match: -// -// else ... -static void buildConditionDispatchTable(IREE::HAL::DeviceSwitchOp switchOp, - OpBuilder funcBuilder) { - // Split the block containing the switch op such that all ops before the - // switch are before and the switch and the following ops are after. - // We'll have all of our inlined regions bounce over to the afterBlock with - // the results of the call and use that to replace the switch op. - auto *beforeBlock = funcBuilder.getBlock(); - auto *afterBlock = beforeBlock->splitBlock(switchOp); - SmallVector locs(switchOp.getNumResults(), switchOp.getLoc()); - auto finalValues = llvm::to_vector( - afterBlock->addArguments(switchOp.getResultTypes(), locs)); - - // Create the blocks we'll use for all our conditions so that we can - // reference them when inserting the branch ops. - SmallVector conditionMatchBlocks( - switchOp.getConditionRegions().size()); - SmallVector conditionFallthroughBlocks( - switchOp.getConditionRegions().size()); - for (int i = 0; i < conditionMatchBlocks.size(); ++i) { - conditionMatchBlocks[i] = funcBuilder.createBlock(afterBlock); - conditionFallthroughBlocks[i] = funcBuilder.createBlock(afterBlock); - } - - funcBuilder.setInsertionPoint(beforeBlock, beforeBlock->end()); - for (auto condition : - llvm::enumerate(llvm::zip_equal(switchOp.getConditions().getValue(), - switchOp.getConditionRegions()))) { - auto conditionAttr = llvm::cast( - std::get<0>(condition.value())); - auto &conditionRegion = std::get<1>(condition.value()); - - // Insert the branch based on the match. We either match and jump to a - // block that will contain the inlined region or don't match and need to - // fall through. - auto isMatch = conditionAttr.buildConditionExpression( - switchOp.getLoc(), switchOp.getDevice(), funcBuilder); - auto *matchBlock = conditionMatchBlocks[condition.index()]; - auto *fallthroughBlock = conditionFallthroughBlocks[condition.index()]; - funcBuilder.create(switchOp.getLoc(), isMatch, matchBlock, - fallthroughBlock); - - // Block that contains the inlined region and then jumps out of the chain. - funcBuilder.setInsertionPointToStart(matchBlock); - inlineConditionRegion(conditionRegion, afterBlock, funcBuilder); - - // Block that we enter to check the next condition. - funcBuilder.setInsertionPointToStart(fallthroughBlock); - if (condition.index() + 1 < conditionFallthroughBlocks.size()) { - // Just continue on - the next loop iteration for the following - // condition will add its IR to the block. - } else { - // Fallthrough of all expressions; die if we expected return values. - funcBuilder.create( - switchOp.getLoc(), - "device not supported in the compiled configuration"); - } - } - - // Remove the switch op and replace its results with the final joined - // results. - switchOp.replaceAllUsesWith(finalValues); -} - -class InlineDeviceSwitchesPass - : public PassWrapper> { -public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - StringRef getArgument() const override { - return "iree-hal-inline-device-switches"; - } - - StringRef getDescription() const override { - return "Inlines hal.device.switch condition regions"; - } - - void runOnOperation() override { - auto funcOp = getOperation(); - SmallVector switchOps; - funcOp->walk([&](IREE::HAL::DeviceSwitchOp switchOp) { - switchOps.push_back(switchOp); - }); - for (auto switchOp : switchOps) { - OpBuilder funcBuilder(switchOp); - buildConditionDispatchTable(switchOp, funcBuilder); - switchOp.erase(); - } - } -}; - -std::unique_ptr> createInlineDeviceSwitchesPass() { - return std::make_unique(); -} - -static PassRegistration pass; - -} // namespace HAL -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 63cf4b711a2b..6979939f0d05 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -35,7 +35,7 @@ namespace HAL { namespace { // Map of original SymbolRefAttr to a list of SymbolRefAttrs in variants. -using EntryPointExpansions = DenseMap>; +using ExportExpansions = DenseMap>; //===----------------------------------------------------------------------===// // Linkage utilities @@ -70,7 +70,7 @@ SymbolRefAttr makeExportSymbolRefAttr(IREE::HAL::ExecutableOp executableOp, static LogicalResult materializeExecutableFromSourceOp( IREE::HAL::ExecutableSourceOp sourceOp, ArrayRef targetAttrs, - EntryPointExpansions &entryPointExpansions) { + ExportExpansions &exportExpansions) { OpBuilder moduleBuilder(sourceOp); // Create the op that will contain the translated executable. @@ -80,7 +80,7 @@ static LogicalResult materializeExecutableFromSourceOp( // With this hand-authored path all variants have the same layout and entry // points and we can just clone them. - auto sourceEntryPointOps = sourceOp.getOps(); + auto sourceExportOps = sourceOp.getExportOps(); // Materialize all of the hal.executable.variant ops for all backends we are // targeting. @@ -92,16 +92,15 @@ static LogicalResult materializeExecutableFromSourceOp( sourceOp->getLoc(), targetAttr.getSymbolNameFragment(), targetAttr); targetSymbolTable.insert(targetVariantOp); OpBuilder variantBuilder(&targetVariantOp.getBlock().back()); - for (auto sourceEntryPointOp : sourceEntryPointOps) { - variantBuilder.clone(*sourceEntryPointOp); + for (auto sourceExportOp : sourceExportOps) { + variantBuilder.clone(*sourceExportOp); // Map the original export names to the new variant exports. - entryPointExpansions[SymbolRefAttr::get( - executableOp.getNameAttr(), - {FlatSymbolRefAttr::get( - sourceEntryPointOp.getNameAttr())})] + exportExpansions[SymbolRefAttr::get(executableOp.getNameAttr(), + {FlatSymbolRefAttr::get( + sourceExportOp.getNameAttr())})] .push_back(makeExportSymbolRefAttr(executableOp, targetVariantOp, - sourceEntryPointOp)); + sourceExportOp)); } // Clone any target-specific object files specified. @@ -124,8 +123,9 @@ static LogicalResult materializeExecutableFromSourceOp( return success(); } -static LogicalResult materializeExecutablesFromSourceOps( - mlir::ModuleOp moduleOp, EntryPointExpansions &entryPointExpansions) { +static LogicalResult +materializeExecutablesFromSourceOps(mlir::ModuleOp moduleOp, + ExportExpansions &exportExpansions) { auto sourceOps = llvm::to_vector<32>(moduleOp.getOps()); for (auto sourceOp : sourceOps) { @@ -139,7 +139,7 @@ static LogicalResult materializeExecutablesFromSourceOps( } if (failed(materializeExecutableFromSourceOp(sourceOp, targetAttrs, - entryPointExpansions))) { + exportExpansions))) { return failure(); } } @@ -271,14 +271,13 @@ cloneFuncWithInterface(mlir::func::FuncOp sourceFuncOp, } // Updates the target entry point symbols of |dispatchOp| to the expanded set of -// variant exports in |entryPointExpansions|. -static void -updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, - const EntryPointExpansions &entryPointExpansions) { +// variant exports in |exportExpansions|. +static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, + const ExportExpansions &exportExpansions) { SmallVector newAttrs; for (auto oldAttr : dispatchOp.getEntryPointRefs()) { - auto it = entryPointExpansions.find(oldAttr); - if (it == entryPointExpansions.end()) { + auto it = exportExpansions.find(oldAttr); + if (it == exportExpansions.end()) { newAttrs.push_back(oldAttr); // preserve existing continue; } @@ -313,7 +312,7 @@ static LogicalResult declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp, IREE::HAL::ExecutableOp targetExecutableOp, const BindingLayoutAnalysis &layoutAnalysis, - EntryPointExpansions &entryPointExpansions) { + ExportExpansions &exportExpansions) { auto variantOps = targetExecutableOp.getBlock().getOps(); OpBuilder executableBuilder(&targetExecutableOp.getBlock().front()); @@ -387,9 +386,9 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp, /*workgroup_local_memory=*/IntegerAttr{}); // Map the original export name to the new variant export. - entryPointExpansions[SymbolRefAttr::get(sourceExecutableOp.getNameAttr(), - {FlatSymbolRefAttr::get( - exportOp.getNameAttr())})] + exportExpansions[SymbolRefAttr::get( + sourceExecutableOp.getNameAttr(), + {FlatSymbolRefAttr::get(exportOp.getNameAttr())})] .push_back(makeExportSymbolRefAttr(targetExecutableOp, variantOp, newExportOp)); @@ -537,12 +536,12 @@ class MaterializeInterfacesPass void runOnOperation() override { SymbolTable symbolTable(getOperation()); - EntryPointExpansions entryPointExpansions; + ExportExpansions exportExpansions; // Handle any hand-authored executables; these only need variant expansion // and no layout analysis as the user specified the layout themselves. if (failed(materializeExecutablesFromSourceOps(getOperation(), - entryPointExpansions))) { + exportExpansions))) { return signalPassFailure(); } @@ -595,7 +594,7 @@ class MaterializeInterfacesPass // Define interfaces for each exported function based on analysis. if (failed(declareEntryPointOps(sourceOp, executableOp, layoutAnalysis, - entryPointExpansions))) { + exportExpansions))) { return signalPassFailure(); } @@ -615,7 +614,7 @@ class MaterializeInterfacesPass // pipeline layout, though, and any that fall through are errors. auto updateDispatchSites = [&](IREE::Stream::CmdDispatchOp dispatchOp) { // Update the export targets to point at the new variants. - updateDispatchTargets(dispatchOp, entryPointExpansions); + updateDispatchTargets(dispatchOp, exportExpansions); // Annotate the dispatch site with binding information if required. // TODO(benvanik): remove this path; shouldn't be needed in real usage. diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index e96254256e9d..249ff78c1394 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -10,12 +10,10 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" -#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -44,7 +42,7 @@ class MaterializeResourceCachesPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); } @@ -91,8 +89,7 @@ class MaterializeResourceCachesPass for (auto executableOp : executableOps) { for (auto variantOp : executableOp.getOps()) { - for (auto exportOp : - variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { definePipelineLayoutOp(exportOp.getLoc(), exportOp.getLayout()); } } @@ -226,21 +223,33 @@ class MaterializeResourceCachesPass // Each case should then cache only executables which contain a matching // ExecutableVariantOp. // Afterwards, canonicalization will take care of de-duping/etc. - DeviceSwitchBuilder switchBuilder(loc, - /*resultTypes=*/TypeRange{executableType}, - deviceValue, blockBuilder); - for (auto executableVariantOp : + SmallVector caseIndices; + SmallVector caseVariantOps; + for (auto variantOp : executableOp.getOps()) { - auto *region = switchBuilder.addConditionRegion( - executableVariantOp.getTarget().getMatchExpression()); - auto &entryBlock = region->front(); - auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock); + caseIndices.push_back(caseIndices.size()); + caseVariantOps.push_back(variantOp); + } + + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseVariantOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + return caseVariantOps[i].buildCondition(deviceValue, builder); + }, + blockBuilder); + + // Allow each variant to define how it is loaded and what pipeline it has. + auto switchOp = blockBuilder.create( + loc, executableType, selectedIndex, caseIndices, caseIndices.size()); + for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) { + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); // Gather each of the pipeline layouts needed for each entry point in // the executable. SmallVector pipelineLayoutValues; - for (auto exportOp : - executableVariantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { auto pipelineLayoutGlobalOp = definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout()); pipelineLayoutValues.push_back( @@ -253,32 +262,29 @@ class MaterializeResourceCachesPass // We want these to all happen inside of this device switch case; they'll // get deduplicated/hoisted if possible in future canonicalization passes. SmallVector constantValues; - for (auto blockOp : llvm::make_early_inc_range( - executableVariantOp - .getOps())) { + for (auto blockOp : + llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { constantValues.append(inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, deviceValue)); blockOp.erase(); } auto executableValue = caseBuilder.createOrFold( - loc, ExecutableType::get(loc.getContext()), deviceValue, - SymbolRefAttr::get( - executableOp.getSymNameAttr(), - {SymbolRefAttr::get(executableVariantOp.getSymNameAttr())}), + loc, executableType, deviceValue, + SymbolRefAttr::get(executableOp.getSymNameAttr(), + {SymbolRefAttr::get(variantOp.getSymNameAttr())}), pipelineLayoutValues, constantValues); - caseBuilder.create(loc, executableValue); + caseBuilder.create(loc, executableValue); } - auto *defaultRegion = switchBuilder.addConditionRegion( - IREE::HAL::MatchAlwaysAttr::get(loc.getContext())); - auto defaultBuilder = OpBuilder::atBlockBegin(&defaultRegion->front()); + // Fallback for no available variant. + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); auto nullValue = defaultBuilder.createOrFold(loc, executableType); - defaultBuilder.create(loc, nullValue); + defaultBuilder.create(loc, nullValue); - auto switchOp = switchBuilder.build(); auto executableValue = switchOp.getResult(0); blockBuilder.create(loc, executableValue, globalOp.getName()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index abc6758062a4..944ae532c1c2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -318,10 +318,6 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // Device management and specialization //---------------------------------------------------------------------------- - // Inline hal.device.switch ops and memoize their queries such that we can - // better CSE/fold dispatch logic. - FunctionLikeNest(passManager).addPass(createInlineDeviceSwitchesPass); - // Memoize device queries such that we don't need to repeatedly ask the same // information at runtime. passManager.addPass(createMemoizeDeviceQueriesPass()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h index 6b9146fd4caf..0c0cd24b57bf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -90,9 +90,6 @@ createAssignTargetDevicesPass(const TargetBackendRegistry &targetRegistry, // removed. std::unique_ptr> createFixupLegacySyncPass(); -// Outlines hal.device.switch conditions into functions and inlines conditions. -std::unique_ptr> createInlineDeviceSwitchesPass(); - // Finds hal.device.query ops and creates variables initialized on startup. std::unique_ptr> createMemoizeDeviceQueriesPass(); @@ -208,7 +205,6 @@ inline void registerHALPasses() { createConvertToHALPass(); createDumpExecutableSourcesPass(""); createElideRedundantCommandsPass(); - createInlineDeviceSwitchesPass(); createFixupLegacySyncPass(); createLinkExecutablesPass(TargetBackendRegistry::getGlobal()); createLinkTargetExecutablesPass(TargetBackendRegistry::getGlobal(), ""); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp index d7ce7f5c265a..492799a08527 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp @@ -180,8 +180,8 @@ externalizeExecutableOp(IREE::HAL::ExecutableOp executableOp, variantOp.setObjectsAttr(builder.getArrayAttr({dataObjectAttr})); // Drop the inner module if present (may already be external). - for (auto moduleOp : - llvm::make_early_inc_range(variantOp.getOps())) { + for (auto moduleOp : llvm::make_early_inc_range( + variantOp.getBody().getOps())) { moduleOp.erase(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 623f210516d7..d9b760af7da9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -23,7 +23,6 @@ iree_lit_test_suite( "dump_executable_sources.mlir", "elide_redundant_commands.mlir", "fixup_legacy_sync.mlir", - "inline_device_switches.mlir", "materialize_dispatch_instrumentation.mlir", "materialize_interfaces.mlir", "materialize_resource_caches.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index 6c2f3420009e..7505915b798d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -21,7 +21,6 @@ iree_lit_test_suite( "dump_executable_sources.mlir" "elide_redundant_commands.mlir" "fixup_legacy_sync.mlir" - "inline_device_switches.mlir" "materialize_dispatch_instrumentation.mlir" "materialize_interfaces.mlir" "materialize_resource_caches.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 84e7e27608e1..1b9e70ab283f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -94,8 +94,12 @@ module attributes {hal.device.targets = [#device_target_cpu]} { %arg1_resource as %arg1_capture: !stream.resource{%c16}, %result_resource as %result_capture: !stream.resource{%c16}) { - // CHECK: hal.device.switch<%[[DEVICE]] : !hal.device> - // CHECK: #hal.device.match.executable.format<"embedded-elf-x86_64"> { + // CHECK-DAG: %{{.+}}, %[[FORMAT_AARCH64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") + // CHECK-DAG: %{{.+}}, %[[FORMAT_X86_64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64") + // CHECK-DAG: %[[SWITCH1:.+]] = arith.select %[[FORMAT_X86_64]], %c1, %c-1 + // CHECK-DAG: %[[SWITCH0:.+]] = arith.select %[[FORMAT_AARCH64]], %c0, %[[SWITCH1]] + // CHECK: scf.index_switch %[[SWITCH0]] + // CHECK: case 0 { // CHECK: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.lookup // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: layout(#pipeline_layout) : !hal.pipeline_layout @@ -107,9 +111,14 @@ module attributes {hal.device.targets = [#device_target_cpu]} { // CHECK: %c2 = (%[[RESULT_BUFFER]] : !hal.buffer)[%c0, %c16] // CHECK: ]) // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch) + // CHECK-SAME: target(@ex::@embedded_elf_aarch64::@dispatch) // CHECK-SAME: workgroups([%c1, %c1, %c1]) - // CHECK: hal.return + // CHECK: scf.yield + // CHECK: } + // CHECK: case 1 { + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch) + // CHECK: scf.yield // CHECK: } stream.cmd.dispatch {@ex::@embedded_elf_aarch64::@dispatch, @ex::@embedded_elf_x86_64::@dispatch}[%c4, %c1, %c1] { ro %arg0_capture[%c0 for %c16] : !stream.resource{%c16}, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir deleted file mode 100644 index 7a06975a1b91..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir +++ /dev/null @@ -1,84 +0,0 @@ -// RUN: iree-opt --allow-unregistered-dialect --split-input-file --iree-hal-inline-device-switches --canonicalize %s | FileCheck %s - -// CHECK-LABEL: @simple_constants -// CHECK-SAME: %[[DEVICE:.+]]: !hal.device -// CHECK-SAME: %[[ARG:.+]]: i32 -func.func @simple_constants(%device : !hal.device, %arg : i32) -> i32 { - // CHECK-DAG: %[[C0:.+]] = arith.constant 0 - %c0 = arith.constant 0 : i32 - // CHECK-DAG: %[[C1:.+]] = arith.constant 1 - %c1 = arith.constant 1 : i32 - // CHECK-DAG: %[[C2:.+]] = arith.constant 2 - %c2 = arith.constant 2 : i32 - // CHECK-DAG: %[[C3:.+]] = arith.constant 3 - // CHECK-DAG: %[[C4:.+]] = arith.constant 4 - %0 = hal.device.switch<%device : !hal.device> -> i32 - // CHECK-NEXT: %{{.+}}, %[[IS0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-v1.?-*") : i1, i1 = false - // CHECK-NEXT: cf.cond_br %[[IS0]], ^bb3(%[[C1]] : i32), ^bb1 - #hal.device.match.id<"vulkan-v1.?-*"> { - hal.return %c1 : i32 - }, - // CHECK-NEXT: ^bb1: - // CHECK-NEXT: %{{.+}}, %[[IS1L:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vmvx") : i1, i1 = false - // CHECK-NEXT: %{{.+}}, %[[IS1R:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-*") : i1, i1 = false - // CHECK-NEXT: %[[IS1:.+]] = arith.ori %[[IS1L]], %[[IS1R]] : i1 - // CHECK-NEXT: cf.cond_br %[[IS1]], ^bb2, ^bb3(%[[C0]] : i32) - // CHECK-NEXT: ^bb2: - // CHECK-NEXT: %[[EQZ:.+]] = arith.cmpi eq, %[[ARG]], %[[C2]] : i32 - // CHECK-NEXT: cf.cond_br %[[EQZ]], ^bb3(%[[C3]] : i32), ^bb3(%[[C4]] : i32) - #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> { - %eqz = arith.cmpi eq, %arg, %c2 : i32 - cf.cond_br %eqz, ^bb_true, ^bb_false - ^bb_true: - %c3 = arith.constant 3 : i32 - hal.return %c3 : i32 - ^bb_false: - %c4 = arith.constant 4 : i32 - hal.return %c4 : i32 - }, - #hal.match.always { - hal.return %c0 : i32 - } - // CHECK-NEXT: ^bb3(%[[RES:.+]]: i32): - // CHECK-NEXT: return %[[RES]] : i32 - return %0 : i32 -} - -// ----- - -// CHECK-LABEL: @no_results -// CHECK-SAME: %[[DEVICE:.+]]: !hal.device -func.func @no_results(%device : !hal.device) { - hal.device.switch<%device : !hal.device> - // CHECK-NEXT: %{{.+}}, %[[IS0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-v1.?-*") : i1, i1 = false - // CHECK-NEXT: cf.cond_br %[[IS0]], ^bb1, ^bb2 - // CHECK-NEXT: ^bb1: - // CHECK-NEXT: "some.op_a"() - // CHECK-NEXT: cf.br ^bb5 - #hal.device.match.id<"vulkan-v1.?-*"> { - "some.op_a"() : () -> () - hal.return - }, - // CHECK-NEXT: ^bb2: - // CHECK-NEXT: %{{.+}}, %[[IS1L:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vmvx") : i1, i1 = false - // CHECK-NEXT: %{{.+}}, %[[IS1R:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-*") : i1, i1 = false - // CHECK-NEXT: %[[IS1:.+]] = arith.ori %[[IS1L]], %[[IS1R]] : i1 - // CHECK-NEXT: cf.cond_br %[[IS1]], ^bb3, ^bb4 - // CHECK-NEXT: ^bb3: - // CHECK-NEXT: "some.op_b"() - // CHECK-NEXT: cf.br ^bb5 - #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> { - "some.op_b"() : () -> () - hal.return - }, - // CHECK-NEXT: ^bb4: - // CHECK-NEXT: "some.op_c"() - // CHECK-NEXT: cf.br ^bb5 - #hal.match.always { - "some.op_c"() : () -> () - hal.return - } - // CHECK-NEXT: ^bb5: - // CHECK-NEXT: return - return -} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index b129edf58db3..e2aa4f845854 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir @@ -122,6 +122,10 @@ module attributes {hal.device.targets = [#hal.device.target<"llvm-cpu">]} { // - If there is no matching hal.executable.variant then the executable will not be cached hal.executable @exe { hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + hal.executable.condition(%device: !hal.device) -> i1 { + %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 + hal.return %selected : i1 + } hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) attributes { workgroup_size = [32 : index, 1 : index, 1 : index] } @@ -155,9 +159,18 @@ hal.executable @exe { // CHECK: util.global private @_executable_exe : !hal.executable // CHECK-NEXT: util.initializer { + +// Switch on the supported formats: // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK: %[[RET:.+]] = hal.device.switch<%[[DEVICE]] : !hal.device> -> !hal.executable -// CHECK: #hal.device.match.executable.format<"vmvx-bytecode-fb"> { +// CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb") +// CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 { +// CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature") +// CHECK: scf.yield %[[FEATURE]] +// CHECK: } +// CHECK: %[[VMVX_VARIANT_SELECTED:.+]] = arith.andi %[[FORMAT_VMVX]], %[[VMVX_CONDITION]] +// CHECK: %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %c0, %c-1 +// CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable +// CHECK: case 0 { // Dependent layouts: // CHECK: %[[LAYOUT0:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout @@ -176,11 +189,11 @@ hal.executable @exe { // CHECK-SAME: constants([%[[CONST_01]]#0, %[[CONST_01]]#1, %[[CONST_2]]]) // CHECK-SAME: : !hal.executable -// CHECK: hal.return %[[EXE]] : !hal.executable -// CHECK: }, -// CHECK: #hal.match.always { +// CHECK: scf.yield %[[EXE]] : !hal.executable +// CHECK: } +// CHECK: default { // CHECK: %[[NULL:.+]] = util.null : !hal.executable -// CHECK: hal.return %[[NULL]] : !hal.executable +// CHECK: scf.yield %[[NULL]] : !hal.executable // CHECK: } // CHECK: util.global.store %[[RET]], @_executable_exe : !hal.executable @@ -247,17 +260,21 @@ util.initializer { util.global private @_executable_exe : !hal.executable util.initializer { %device = hal.ex.shared_device : !hal.device - %0 = hal.device.switch<%device : !hal.device> -> !hal.executable - #hal.device.match.executable.format<"vmvx-bytecode-fb"> { + %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1 + %c0 = arith.constant 0 : index + %c-1 = arith.constant -1 : index + %variant = arith.select %format_supported, %c0, %c-1 : index + %selected = scf.index_switch %variant -> !hal.executable + case 0 { %_pipeline_layout_0 = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout %exe = hal.executable.create device(%device : !hal.device) target(@exe0::@vmvx) layouts([%_pipeline_layout_0]) : !hal.executable - hal.return %exe : !hal.executable - }, - #hal.match.always { - %1 = util.null : !hal.executable - hal.return %1 : !hal.executable + scf.yield %exe : !hal.executable + } + default { + %null = util.null : !hal.executable + scf.yield %null : !hal.executable } - util.global.store %0, @_executable_exe : !hal.executable + util.global.store %selected, @_executable_exe : !hal.executable util.initializer.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel deleted file mode 100644 index e7696df4e03b..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2019 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_compiler_cc_library( - name = "Utils", - hdrs = [ - "DeviceSwitchBuilder.h", - ], - deps = [ - "//compiler/src/iree/compiler/Dialect/HAL/IR", - "//compiler/src/iree/compiler/Utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt deleted file mode 100644 index 5509ab2ad92c..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_cc_library( - NAME - Utils - HDRS - "DeviceSwitchBuilder.h" - DEPS - LLVMSupport - MLIRFuncDialect - MLIRIR - MLIRSupport - MLIRTransforms - iree::compiler::Dialect::HAL::IR - iree::compiler::Utils - PUBLIC -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h b/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h deleted file mode 100644 index abb93b54acda..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_ -#define IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_ - -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/RegionUtils.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace HAL { - -// See DeviceSwitchBuilder for details. -class DeviceSwitchCaseBuilder { -public: - DeviceSwitchCaseBuilder(Location loc, TypeRange resultTypes, Value device, - Attribute initialCondition, - SmallVectorImpl &caseOps, - OpBuilder &builder) - : loc_(loc), resultTypes_(resultTypes), device_(device), - initialCondition_(initialCondition), caseOps_(caseOps), - builder_(builder) {} - - // Result types that each region must return. - TypeRange resultTypes() { return resultTypes_; } - - // Runtime device the switch will match against. - Value device() { return device_; } - - // Pushes a new condition onto the stack and returns a builder that must have - // all previously nested conditions met in order to execute any conditions. - DeviceSwitchCaseBuilder nest(Attribute conditionAttr) { - auto matchAttr = - initialCondition_ - ? IREE::HAL::MatchAllAttr::get( - conditionAttr.getContext(), - ArrayRef{initialCondition_, conditionAttr}) - : conditionAttr; - return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, matchAttr, - caseOps_, builder_); - } - - // Adds a new condition region that must satisfy all parent conditions. - // The region will have a single empty entry block. - Region *addRegion() { - auto switchOp = builder_.create( - loc_, resultTypes_, device_, ArrayRef{initialCondition_}); - auto *region = &switchOp.getRegion(0); - OpBuilder(region).createBlock(region); - caseOps_.emplace_back(switchOp); - return region; - } - - // Adds a new condition region that must satisfy |conditionAttr| and all - // parent conditions. The region will have a single empty entry block. - Region *addConditionRegion(Attribute conditionAttr) { - return nest(conditionAttr).addRegion(); - } - -private: - Location loc_; - SmallVector resultTypes_; - Value device_; - Attribute initialCondition_; - SmallVectorImpl &caseOps_; - OpBuilder &builder_; -}; - -// Builder for hal.device.switch ops that allows for nesting of conditions. -// -// Example: -// DeviceSwitchBuilder builder(); -// auto b0 = builder.nest(Z); -// b0.addRegion(); // condition: Z -// b0.addConditionRegion(A); // condition: Z && A -// auto b1 = b0.nest(B); -// b1.addConditionRegion(C); // condition: Z && B && C -// b1.addConditionRegion(D); // condition: Z && B && D -// auto b2 = b1.nest(E); -// b2.addRegion(); // condition: Z && B && E -// b2.addConditionRegion(F); // condition: Z && B && E && F -// -// Note that the arguments passed into addRegion/addConditionRegion are captured -// from outside of the switch and accessible as entry block arguments on the -// region that captured them. You must query the returned Region entry block -// arguments to use them within the region. -class DeviceSwitchBuilder { -public: - DeviceSwitchBuilder(Location loc, TypeRange resultTypes, Value device, - OpBuilder builder) - : loc_(loc), resultTypes_(resultTypes), device_(device), - builder_(builder) {} - - // Pushes a new condition onto the stack and returns a builder that must have - // all previously nested conditions met in order to execute any conditions. - DeviceSwitchCaseBuilder nest(Attribute conditionAttr) { - return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr, - caseOps_, builder_); - } - - // Adds a new condition region that must satisfy |conditionAttr| and all - // parent conditions. The region will have a single entry block with the - // given |args|. - Region *addConditionRegion(Attribute conditionAttr) { - return nest(conditionAttr).addRegion(); - } - - // Constructs a single hal.device.switch from all added regions. - IREE::HAL::DeviceSwitchOp build() { - SmallVector conditionAttrs; - llvm::SetVector capturedFromAbove; - for (auto caseOp : caseOps_) { - conditionAttrs.push_back(caseOp.getConditions().getValue()[0]); - } - auto switchOp = builder_.create( - loc_, resultTypes_, device_, conditionAttrs); - for (int i = 0; i < caseOps_.size(); ++i) { - switchOp.getRegion(i).takeBody(caseOps_[i].getRegion(0)); - caseOps_[i].erase(); - } - return switchOp; - } - -private: - Location loc_; - SmallVector resultTypes_; - Value device_; - SmallVector caseOps_; - OpBuilder builder_; -}; - -// Rewriter-compatible version of DeviceSwitchBuilder. -class DeviceSwitchRewriter { -public: - DeviceSwitchRewriter(Location loc, TypeRange resultTypes, Value device, - ConversionPatternRewriter &rewriter) - : loc_(loc), resultTypes_(resultTypes), device_(device), - rewriter_(rewriter) {} - - // Pushes a new condition onto the stack and returns a builder that must have - // all previously nested conditions met in order to execute any conditions. - DeviceSwitchCaseBuilder nest(Attribute conditionAttr) { - return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr, - caseOps_, rewriter_); - } - - // Adds a new condition region that must satisfy |conditionAttr| and all - // parent conditions. The region will have a single empty entry block. - Region *addConditionRegion(Attribute conditionAttr) { - return nest(conditionAttr).addRegion(); - } - - // Constructs a single hal.device.switch from all added regions. - IREE::HAL::DeviceSwitchOp build() { - SmallVector conditionAttrs; - llvm::SetVector capturedFromAbove; - for (auto caseOp : caseOps_) { - conditionAttrs.push_back(caseOp.getConditions().getValue()[0]); - } - auto switchOp = rewriter_.create( - loc_, resultTypes_, device_, conditionAttrs); - for (int i = 0; i < caseOps_.size(); ++i) { - Region &targetRegion = switchOp.getRegion(i); - - SmallVector entryTypes; - Block *entryBlock = - rewriter_.createBlock(&targetRegion, targetRegion.end(), entryTypes); - rewriter_.setInsertionPointAfter(switchOp); - - IRMapping mapper; - - Region &sourceRegion = caseOps_[i].getRegion(0); - // When cloning `sourceRegion` into `targetRegion` remap the captured - // values to use arguments of the `targetRegion`. - rewriter_.cloneRegionBefore(sourceRegion, targetRegion, - ++(Region::iterator(entryBlock)), mapper); - Block *secondBlock = entryBlock->getNextNode(); - rewriter_.mergeBlocks(secondBlock, entryBlock, {}); - rewriter_.eraseOp(caseOps_[i]); - } - return switchOp; - } - - ConversionPatternRewriter &getRewriter() const { return rewriter_; } - -private: - Location loc_; - SmallVector resultTypes_; - Value device_; - SmallVector caseOps_; - ConversionPatternRewriter &rewriter_; -}; - -} // namespace HAL -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_ diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 11865dc469a8..5e4f0ead8db0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -25,6 +25,37 @@ namespace mlir { namespace iree_compiler { +//===----------------------------------------------------------------------===// +// Experimental +//===----------------------------------------------------------------------===// + +// For now we emit all cases and then select the first found (by selecting +// in reverse). So if selecting between case0, case1, and case2 we'd end up with +// %case0 = ... +// %case1 = ... +// %case2 = ... +// %0 = arith.select %case2, %c2, %c-1 +// %1 = arith.select %case1, %c1, %0 +// %2 = arith.select %case0, %c0, %1 +// // %2 is now -1 if nothing matched or the index of the match +Value buildIfElseTree( + Location loc, size_t count, + std::function caseBuilder, + OpBuilder &builder) { + SmallVector caseValues; + caseValues.reserve(count); + for (size_t i = 0; i < count; ++i) { + caseValues.push_back(caseBuilder(loc, i, builder)); + } + Value result = builder.create(loc, -1); + for (int i = count - 1; i >= 0; --i) { + result = builder.create( + loc, caseValues[i], builder.create(loc, i), + result); + } + return result; +} + //===----------------------------------------------------------------------===// // Utils //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h index a815359d4bb5..b89675f9b640 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h @@ -30,6 +30,20 @@ namespace mlir { namespace iree_compiler { +//===----------------------------------------------------------------------===// +// Experimental +//===----------------------------------------------------------------------===// + +// NOTE: this is a placeholder for a util.tree_switch (or something) op that +// looks like scf.index_switch but with a region per case. For now we emit a +// sequence of arith.select ops and return the index of the first condition that +// is true. Would be nicer with some range template magic instead of an index. +// Returns an index of -1 if no case matches. +Value buildIfElseTree( + Location loc, size_t count, + std::function caseBuilder, + OpBuilder &builder); + //===----------------------------------------------------------------------===// // Utils //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel index d1998b0b6923..8b3422748845 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Modules/HAL/Inline/IR", "//compiler/src/iree/compiler/Modules/HAL/Inline/IR:HALInlineDialect", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt index c289898722a9..a895f811989b 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt @@ -28,7 +28,6 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Util::IR iree::compiler::Modules::HAL::Inline::IR iree::compiler::Modules::HAL::Inline::IR::HALInlineDialect diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel index 2e7554df189d..e21b910914f7 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Modules/HAL/Inline/IR", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt index 645b1e01384a..a20f7a6714fb 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt @@ -28,7 +28,6 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR iree::compiler::Modules::HAL::Inline::IR diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp index 27880e0c45e5..abb617943fa1 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp @@ -100,7 +100,7 @@ class InlineExecutablesPass auto indexType = innerModuleBuilder.getIndexType(); auto i32Type = innerModuleBuilder.getI32Type(); auto bufferType = innerModuleBuilder.getType(); - for (auto exportOp : variantOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { // Build dispatch function signature that the stream.cmd.dispatch ops will // map to. auto layoutAttr = exportOp.getLayout(); diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel index 49aa5434e23d..12996eee689f 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel @@ -25,7 +25,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", - "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Modules/HAL/Inline/IR", diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt index a6ffbb4943ce..44ae18267ef5 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt @@ -28,7 +28,6 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target - iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR iree::compiler::Modules::HAL::Inline::IR diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp index 005d4424f14e..7861ee2fd8d5 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp @@ -47,8 +47,8 @@ struct CmdDispatchOpPattern auto loc = dispatchOp.getLoc(); // TODO(benvanik): support a lightweight switch builder for picking variants - // that doesn't pull in the full HAL dialect - today the - // DeviceSwitchRewriter needs a !hal.device and its query methods. + // that doesn't pull in the full HAL dialect. We could make the match + // expressions take a callback that performs the query, for example. // For now we bail if there's multiple. auto entryPointAttrs = dispatchOp.getEntryPoints().getValue(); if (entryPointAttrs.size() != 1) { @@ -76,10 +76,9 @@ struct CmdDispatchOpPattern loc, rewriter.getType(), executableOp.getName()); - // TODO(benvanik): a real switch op. For now we inline what the - // hal.device.switch op does. + // TODO(benvanik): use scf.index_switch as with the full HAL. for (auto variantOp : variantOps) { - auto exportOps = variantOp.getOps(); + auto exportOps = variantOp.getExportOps(); auto exportIt = llvm::find_if(exportOps, [&](IREE::HAL::ExecutableExportOp op) { return op.getNameAttr() == entryPointAttr.getLeafReference(); diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir index b213bd384b36..980df872b23b 100644 --- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir @@ -53,6 +53,16 @@ module @example attributes {hal.device.targets = [#vulkan_target]} { // Dispatch a basic `ret = lhs * rhs` shader. %0 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor{%dim}, tensor{%dim}) -> tensor{%dim} + count(%device: !hal.device, %workload: index) -> (index, index, index) { + // This host function is used to compute the XYZ workgroup count + // dispatched at runtime. It can query the %device for capabilities + // and limits (shared memory size, etc). The other arguments are the + // values passed in the dispatch operation (usually things like root + // output op tensor dimensions and other abstract values). + %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] + %c1 = arith.constant 1 : index + hal.return %x, %c1, %c1 : index, index, index + } // The layout defines the required bindings and push constants and can be // thought of as the function signature. layout(#hal.pipeline.layout ] }>) - count(%device: !hal.device, %workload: index) -> (index, index, index) { - // This host function is used to compute the XYZ workgroup count - // dispatched at runtime. It can query the %device for capabilities - // and limits (shared memory size, etc). The other arguments are the - // values passed in the dispatch operation (usually things like root - // output op tensor dimensions and other abstract values). - %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] - %c1 = arith.constant 1 : index - hal.return %x, %c1, %c1 : index, index, index - } // Code gen some other ops - these will interleave with the hand-authored // ones but naturally won't be able to fuse with them.