Skip to content

Commit

Permalink
Making flow.dispatch/stream.async.dispatch take multiple symbols. (ir…
Browse files Browse the repository at this point in the history
…ee-org#15295)

stream.cmd.dispatch already supported this for making external
hal.executable.variant ops work and by making this consistent up the
stack it allows for the use of hal.executable.variant all the way up in
flow. This will allow hal.dispatch.extern to expand to HAL ops instead
of flow.executable and avoid the need for plumbing all of the
HAL-specific behavior through those layers.
  • Loading branch information
benvanik authored Oct 25, 2023
1 parent a3a64d2 commit 1b177e9
Show file tree
Hide file tree
Showing 18 changed files with 335 additions and 165 deletions.
28 changes: 28 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,34 @@ void DispatchTensorStoreOp::getCanonicalizationPatterns(
context);
}

//===----------------------------------------------------------------------===//
// flow.dispatch
//===----------------------------------------------------------------------===//

namespace {

struct DeduplicateDispatchEntryRefs final
: public OpRewritePattern<DispatchOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchOp dispatchOp,
PatternRewriter &rewriter) const override {
auto originalAttr = dispatchOp.getEntryPointsAttr();
auto newAttr = deduplicateArrayElements(originalAttr);
if (newAttr == originalAttr)
return failure();
rewriter.updateRootInPlace(
dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); });
return success();
}
};

} // namespace

void DispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<DeduplicateDispatchEntryRefs>(context);
}

//===----------------------------------------------------------------------===//
// Tensor ops
//===----------------------------------------------------------------------===//
Expand Down
107 changes: 83 additions & 24 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,44 @@ static void printDispatchWorkgroupsCountRegion(OpAsmPrinter &p, Operation *op,
printWorkgroupCountRegionWithoutKeyword(p, op, body);
}

//===----------------------------------------------------------------------===//
// custom<DispatchEntryPoints>($entry_points)
//===----------------------------------------------------------------------===//

static ParseResult parseDispatchEntryPoints(OpAsmParser &parser,
ArrayAttr &entryPointAttrsArray) {
SmallVector<Attribute> entryPointAttrs;
if (succeeded(parser.parseOptionalLBrace())) {
do {
SymbolRefAttr entryPointAttr;
if (failed(parser.parseAttribute(entryPointAttr)))
return failure();
entryPointAttrs.push_back(entryPointAttr);
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRBrace()))
return failure();
} else {
SymbolRefAttr entryPointAttr;
if (failed(parser.parseAttribute(entryPointAttr)))
return failure();
entryPointAttrs.push_back(entryPointAttr);
}
entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs);
return success();
}

static void printDispatchEntryPoints(OpAsmPrinter &p, Operation *op,
ArrayAttr entryPointAttrs) {
if (entryPointAttrs.size() == 1) {
p.printAttribute(entryPointAttrs.getValue().front());
} else {
p << '{';
llvm::interleaveComma(entryPointAttrs, p.getStream(),
[&](Attribute attr) { p.printAttribute(attr); });
p << '}';
}
}

//===----------------------------------------------------------------------===//
// flow.dispatch.region
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1329,7 +1367,7 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state,
ValueRange operands, ValueRange operandDims,
ArrayAttr tiedOperands,
ArrayRef<NamedAttribute> attributes) {
state.addAttribute("entry_point", entryPoint);
state.addAttribute("entry_points", builder.getArrayAttr(entryPoint));
state.addOperands(workload);
state.addTypes(resultTypes);
state.addOperands(operands);
Expand All @@ -1349,51 +1387,72 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state,
}));
}

StringAttr DispatchOp::executable() {
return getEntryPoint().getRootReference();
}

FunctionType DispatchOp::getEntryPointType() {
SmallVector<Type, 8> argTypes(operand_type_range{getArguments()});
return FunctionType::get(getContext(), argTypes, getResultTypes());
}

std::string DispatchOp::getEntryPointName() {
// Pick the first entry point we have. The common case is we only have one
// but frontends may provide multiple variants - they're all likely the
// same name but with slight differences and enough for a user to know what's
// happening.
auto anyEntryPoint = *getEntryPointRefs().begin();
std::string entryPointName =
anyEntryPoint.getRootReference().getValue().str();
for (FlatSymbolRefAttr nestedRef : anyEntryPoint.getNestedReferences()) {
entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
}
return entryPointName;
}

std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
return getODSOperandIndexAndLength(1); // $operands
}

LogicalResult DispatchOp::verify() {
Operation *op = getOperation();

if (getEntryPoints().empty()) {
return op->emitOpError("at least one entry point reference is required");
}

if (failed(verifyOpDynamicDims(op, getArguments(), getArgumentDims())) ||
failed(verifyOpDynamicDims(op, getResults(), getResultDims()))) {
return failure();
}

return success();
}

LogicalResult DispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *op = getOperation();
auto exportOp =
symbolTable.lookupNearestSymbolFrom<IREE::Flow::ExecutableExportOp>(
op, getEntryPoint());
if (!exportOp) {
// TODO(benvanik): there are a lot of tests that are assuming this is not
// verified. We'll need to go add dummy executables for all of them. Today
// we just bail on the verifier if the symbol isn't found.
//
// Should be:
// return op->emitOpError() << "undefined entry point: " <<
// getEntryPoint();
return success();
}

// Verify that the workload parameters captured match the target export.
if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
return failure();
}
auto entryPointRefs = getEntryPointRefs();
if (entryPointRefs.empty()) {
return emitOpError() << "at least one entry point must be defined";
}
for (auto entryPointAttr : entryPointRefs) {
auto exportOp =
symbolTable.lookupNearestSymbolFrom<IREE::Flow::ExecutableExportOp>(
op, entryPointAttr);
if (!exportOp) {
// TODO(benvanik): there are a lot of tests that are assuming this is not
// verified. We'll need to go add dummy executables for all of them. Today
// we just bail on the verifier if the symbol isn't found.
//
// Should be:
// return op->emitOpError() << "undefined entry point: " <<
// getEntryPoint();
return success();
}

// TODO(benvanik): verify that the target function has matching operands.
// Verify that the workload parameters captured match the target export.
if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
return failure();
}

// TODO(benvanik): verify that the target function has matching operands.
}
return success();
}

Expand Down
17 changes: 14 additions & 3 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [

let arguments = (ins
Variadic<FLOW_Dim>:$workload,
SymbolRefAttr:$entry_point,
SymbolRefArrayAttr:$entry_points,
Variadic<AnyType>:$arguments,
FLOW_ShapeDynamicDims:$argument_dims,
FLOW_ShapeDynamicDims:$result_dims,
Expand All @@ -777,7 +777,7 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [
);

let assemblyFormat = [{
$entry_point
custom<DispatchEntryPoints>($entry_points)
(`[` $workload^ `]`)? ``
`(` $arguments `)` attr-dict `:`
custom<ShapedFunctionType>(ref($arguments),
Expand Down Expand Up @@ -827,9 +827,19 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [
];

let extraClassDeclaration = [{
StringAttr executable();
FunctionType getEntryPointType();

auto getEntryPointRefs() {
return getEntryPoints().getAsRange<SymbolRefAttr>();
}
void forEachEntryPointAttr(std::function<void(SymbolRefAttr)> fn) {
for (auto entryPointAttr : getEntryPointRefs()) fn(entryPointAttr);
}

// Returns a human-friendly string name for what is being dispatched.
// May not be unique or a valid reference to an executable.
std::string getEntryPointName();

// StreamableOpInterface:
bool isTransfer() { return false; }

Expand All @@ -841,6 +851,7 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [
}
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand Down
33 changes: 19 additions & 14 deletions compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: iree-opt --split-input-file %s --verify-diagnostics | FileCheck %s

flow.executable @ex0 {
flow.executable.export @dispatch_fn
builtin.module {
func.func @dispatch_fn(%cst : index, %arg0 : tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
}
flow.executable.export @dispatch_fn
}

// CHECK-LABEL: @dispatch
Expand All @@ -18,21 +18,31 @@ func.func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
return %0 : tensor<4xf32>
}

// -----

flow.executable private @ex0 {
flow.executable.export public @dispatch_a
flow.executable.export public @dispatch_b
}

// CHECK-LABEL: @dispatchWithMultipleRefs
func.func @dispatchWithMultipleRefs(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: = flow.dispatch {@ex0::@dispatch_a, @ex0::@dispatch_b}(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
%0 = flow.dispatch {@ex0::@dispatch_a, @ex0::@dispatch_b}(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}


// -----

flow.executable private @ex0 {
flow.executable.export public @dispatch workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
flow.return %arg0, %arg1, %arg0 : index, index, index
}
builtin.module {
func.func @dispatch() {
return
}
}
}

// CHECK-LABEL: @asyncDispatchWithWorkgroupCount
func.func @asyncDispatchWithWorkgroupCount(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
// CHECK-LABEL: @dispatchWithWorkgroupCount
func.func @dispatchWithWorkgroupCount(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// CHECK: = flow.dispatch @ex0::@dispatch[%c1, %c2](%arg0, %arg1) : (tensor<4xf32>, index) -> tensor<4xf32>
Expand All @@ -46,14 +56,9 @@ flow.executable private @ex0 {
flow.executable.export public @dispatch workgroups(%arg0: index) -> (index, index, index) {
flow.return %arg0, %arg0, %arg0 : index, index, index
}
builtin.module {
func.func @dispatch() {
return
}
}
}

func.func @asyncDispatchWithInvalidWorkload(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
func.func @dispatchWithInvalidWorkload(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// expected-error @+1 {{op workload mismatch; entry point expects 1 arguments but dispatch provides 2}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,17 @@ class AnnotateDispatchesPass
// new symbol name.
for (auto funcLikeOp : getOperation().getOps<FunctionOpInterface>()) {
funcLikeOp->walk([&](IREE::Flow::DispatchOp dispatchOp) {
auto it = entryPointRefReplacements.find(dispatchOp.getEntryPoint());
if (it != entryPointRefReplacements.end()) {
dispatchOp.setEntryPointAttr(llvm::cast<SymbolRefAttr>(it->second));
SmallVector<Attribute> replacementRefs;
for (auto originalRef : dispatchOp.getEntryPointRefs()) {
auto it = entryPointRefReplacements.find(originalRef);
if (it != entryPointRefReplacements.end()) {
replacementRefs.push_back(it->second);
} else {
replacementRefs.push_back(originalRef);
}
}
dispatchOp.setEntryPointsAttr(
ArrayAttr::get(dispatchOp.getContext(), replacementRefs));
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,20 @@ void replaceEntryPointUses(
const DenseMap<Attribute, SymbolRefAttr> &replacements) {
for (auto funcLikeOp : moduleOp.getOps<FunctionOpInterface>()) {
funcLikeOp->walk([&](DispatchOp dispatchOp) {
auto it = replacements.find(dispatchOp.getEntryPoint());
if (it != replacements.end()) {
dispatchOp.setEntryPointAttr(llvm::cast<SymbolRefAttr>(it->second));
bool didChange = false;
SmallVector<Attribute> newAttrs;
for (auto oldAttr : dispatchOp.getEntryPoints()) {
auto it = replacements.find(oldAttr);
if (it != replacements.end()) {
didChange = true;
newAttrs.push_back(it->second);
} else {
newAttrs.push_back(oldAttr);
}
}
if (didChange) {
dispatchOp.setEntryPointsAttr(
ArrayAttr::get(moduleOp.getContext(), newAttrs));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ class DumpDispatchGraphPass
void printDispatchBody(raw_ostream &os, DispatchOp &dispatchOp) {
// Find the entry point function from the dispatch entry point symbol
// attribute.
auto entryPoint = dispatchOp.getEntryPoint();
auto entryPoint = *dispatchOp.getEntryPointRefs().begin();
auto executableOp = cast<ExecutableOp>(SymbolTable::lookupNearestSymbolFrom(
dispatchOp, entryPoint.getRootReference()));
if (!executableOp)
Expand Down Expand Up @@ -452,7 +452,7 @@ class DumpDispatchGraphPass
// Print entry function name, if there is only one entry function,
// then the name space and the entry function names are the same,
// and we can just print the function name to save space.
auto entryPoint = dispatch.getEntryPoint();
auto entryPoint = *dispatch.getEntryPointRefs().begin();
auto rootName = entryPoint.getRootReference();
auto leafName = entryPoint.getLeafReference();
if (rootName == leafName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,7 @@ class InjectDispatchTracingPass
void runOnOperation() override {
auto funcOp = getOperation();
for (auto dispatchOp : funcOp.getFunctionBody().getOps<DispatchOp>()) {
std::string entryPointName =
dispatchOp.getEntryPoint().getRootReference().getValue().str();
for (FlatSymbolRefAttr nestedRef :
dispatchOp.getEntryPoint().getNestedReferences()) {
entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
}
std::string entryPointName = dispatchOp.getEntryPointName();

// Input tensors:
OpBuilder builder(dispatchOp);
Expand Down
Loading

0 comments on commit 1b177e9

Please sign in to comment.