Skip to content

Commit

Permalink
[HandshakeToDC] Turn conversion into library-style
Browse files Browse the repository at this point in the history
Factors out the conversion s.t. it can be run on arbitrary operations. To facilitate this, the conversion will accept builders for additional patterns to apply during the converison process. Having this, users may specify patterns for specific ops (to avoid applying the default unit-rate actor pattern to said ops).
  • Loading branch information
mortbopet committed Oct 2, 2023
1 parent cefec09 commit c02d457
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 59 deletions.
16 changes: 16 additions & 0 deletions include/circt/Conversion/HandshakeToDC.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ namespace circt {

std::unique_ptr<mlir::Pass> createHandshakeToDCPass();

namespace handshaketodc {
using ConvertedOps = DenseSet<Operation *>;

// Runs Handshake to DC conversion on the provided op. `patternBuilder` can be
// used to describe additional patterns to run - typically this will be a
// pattern that converts the container operation (e.g. `op`).
// `configureTarget` can be provided to specialize legalization.
LogicalResult runHandshakeToDC(
mlir::Operation *op,
llvm::function_ref<void(TypeConverter &typeConverter,
ConvertedOps &convertedOps,
RewritePatternSet &patterns)>
patternBuilder,
llvm::function_ref<void(mlir::ConversionTarget &)> configureTarget = {});
} // namespace handshaketodc

namespace handshake {

// Converts 't' into a valid HW type. This is strictly used for converting
Expand Down
142 changes: 83 additions & 59 deletions lib/Conversion/HandshakeToDC/HandshakeToDC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ using namespace circt;
using namespace handshake;
using namespace dc;
using namespace hw;
using namespace handshaketodc;

namespace {

using ConvertedOps = DenseSet<Operation *>;

struct DCTuple {
DCTuple() = default;
DCTuple(Value token, Value data) : token(token), data(data) {}
Expand Down Expand Up @@ -308,9 +307,6 @@ struct UnitRateConversionPattern : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return op->emitOpError("expected single result for pattern to apply");

llvm::SmallVector<Value, 4> inputData;
llvm::SmallVector<Value, 4> inputTokens;
for (auto input : operands) {
Expand All @@ -337,9 +333,12 @@ struct UnitRateConversionPattern : public ConversionPattern {
Operation *newOp = rewriter.create(state);
joinedOps->insert(newOp);

// Pack the result token with the output data, and replace the use.
rewriter.replaceOp(op, ValueRange{pack(rewriter, join.getResult(),
newOp->getResults().front())});
// Pack the result token with the output data, and replace the uses.
llvm::SmallVector<Value> results;
for (auto result : newOp->getResults())
results.push_back(pack(rewriter, join, result));

rewriter.replaceOp(op, results);

return success();
}
Expand Down Expand Up @@ -391,9 +390,9 @@ class BufferOpConversion : public DCOpConversionPattern<handshake::BufferOp> {
}
};

class ReturnOpConversion : public DCOpConversionPattern<handshake::ReturnOp> {
class ReturnOpConversion : public OpConversionPattern<handshake::ReturnOp> {
public:
using DCOpConversionPattern<handshake::ReturnOp>::DCOpConversionPattern;
using OpConversionPattern<handshake::ReturnOp>::OpConversionPattern;
using OpAdaptor = typename handshake::ReturnOp::Adaptor;

LogicalResult
Expand Down Expand Up @@ -504,9 +503,9 @@ static hw::ModulePortInfo getModulePortInfoHS(const TypeConverter &tc,
return hw::ModulePortInfo{inputs, outputs};
}

class FuncOpConversion : public DCOpConversionPattern<handshake::FuncOp> {
class FuncOpConversion : public OpConversionPattern<handshake::FuncOp> {
public:
using DCOpConversionPattern<handshake::FuncOp>::DCOpConversionPattern;
using OpConversionPattern<handshake::FuncOp>::OpConversionPattern;
using OpAdaptor = typename handshake::FuncOp::Adaptor;

// Replaces a handshake.func with a hw.module, converting the argument and
Expand Down Expand Up @@ -547,53 +546,19 @@ class HandshakeToDCPass : public HandshakeToDCBase<HandshakeToDCPass> {
public:
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();

// Maintain the set of operations which has been converted either through
// unit rate conversion, or as part of other conversions.
// Rationale:
// This is needed for all of the arith ops that get created as part of the
// handshake ops (e.g. arith.select for handshake.mux). There's a bit of a
// dilemma here seeing as all operations need to be converted/touched in a
// handshake.func - which is done so by UnitRateConversionPattern (when no
// other pattern applies). However, we obviously don't want to run said
// pattern on these newly created ops since they do not have handshake
// semantics.
ConvertedOps convertedOps;

ConversionTarget target(getContext());
target.addIllegalDialect<handshake::HandshakeDialect>();
target.addLegalDialect<dc::DCDialect, func::FuncDialect, hw::HWDialect>();
target.addLegalOp<mlir::ModuleOp>();

// The various patterns will insert new operations into the module to
// facilitate the conversion - however, these operations must be
// distinguishable from already converted operations (which may be of the
// same type as the newly inserted operations). To do this, we mark all
// operations which have been converted as legal, and all other operations
// as illegal.
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return convertedOps.contains(op); });

DCTypeConverter typeConverter;
RewritePatternSet patterns(&getContext());

// Add handshake conversion patterns.
// Note: merge/control merge are not supported - these are non-deterministic
// operators and we do not care for them.
patterns
.add<FuncOpConversion, BufferOpConversion, CondBranchConversionPattern,
SinkOpConversionPattern, SourceOpConversionPattern,
MuxOpConversionPattern, ReturnOpConversion,
ForkOpConversionPattern, JoinOpConversion,
ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
&getContext(), typeConverter, &convertedOps);

// ALL other single-result operations are converted via the
// UnitRateConversionPattern.
patterns.add<UnitRateConversionPattern>(&getContext(), typeConverter,
&convertedOps);

if (failed(applyPartialConversion(mod, target, std::move(patterns))))
auto targetModifier = [](mlir::ConversionTarget &target) {
target.addLegalDialect<hw::HWDialect, func::FuncDialect>();
};

auto patternBuilder = [&](TypeConverter &typeConverter,
handshaketodc::ConvertedOps &convertedOps,
RewritePatternSet &patterns) {
patterns.add<FuncOpConversion, ReturnOpConversion>(typeConverter,
mod.getContext());
};

LogicalResult res = runHandshakeToDC(mod, patternBuilder, targetModifier);
if (failed(res))
signalPassFailure();
}
};
Expand All @@ -602,3 +567,62 @@ class HandshakeToDCPass : public HandshakeToDCBase<HandshakeToDCPass> {
std::unique_ptr<mlir::Pass> circt::createHandshakeToDCPass() {
return std::make_unique<HandshakeToDCPass>();
}

LogicalResult circt::handshaketodc::runHandshakeToDC(
mlir::Operation *op,
llvm::function_ref<void(TypeConverter &typeConverter,
handshaketodc::ConvertedOps &convertedOps,
RewritePatternSet &patterns)>
patternBuilder,
llvm::function_ref<void(mlir::ConversionTarget &)> configureTarget) {
// Maintain the set of operations which has been converted either through
// unit rate conversion, or as part of other conversions.
// Rationale:
// This is needed for all of the arith ops that get created as part of the
// handshake ops (e.g. arith.select for handshake.mux). There's a bit of a
// dilemma here seeing as all operations need to be converted/touched in a
// handshake.func - which is done so by UnitRateConversionPattern (when no
// other pattern applies). However, we obviously don't want to run said
// pattern on these newly created ops since they do not have handshake
// semantics.
handshaketodc::ConvertedOps convertedOps;
mlir::MLIRContext *ctx = op->getContext();
ConversionTarget target(*ctx);
target.addIllegalDialect<handshake::HandshakeDialect>();
target.addLegalDialect<dc::DCDialect>();
target.addLegalOp<mlir::ModuleOp>();

// And any user-specified target adjustments
if (configureTarget)
configureTarget(target);

// The various patterns will insert new operations into the module to
// facilitate the conversion - however, these operations must be
// distinguishable from already converted operations (which may be of the
// same type as the newly inserted operations). To do this, we mark all
// operations which have been converted as legal, and all other operations
// as illegal.
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return convertedOps.contains(op); });

DCTypeConverter typeConverter;
RewritePatternSet patterns(ctx);

// Add handshake conversion patterns.
// Note: merge/control merge are not supported - these are non-deterministic
// operators and we do not care for them.
patterns
.add<BufferOpConversion, CondBranchConversionPattern,
SinkOpConversionPattern, SourceOpConversionPattern,
MuxOpConversionPattern, ForkOpConversionPattern, JoinOpConversion,
ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
ctx, typeConverter, &convertedOps);

// ALL other single-result operations are converted via the
// UnitRateConversionPattern.
patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);

// Build any user-specified patterns
patternBuilder(typeConverter, convertedOps, patterns);
return applyPartialConversion(op, target, std::move(patterns));
}

0 comments on commit c02d457

Please sign in to comment.