diff --git a/include/circt/Conversion/HandshakeToDC.h b/include/circt/Conversion/HandshakeToDC.h index a3bcf3fc3fc2..802a572c2291 100644 --- a/include/circt/Conversion/HandshakeToDC.h +++ b/include/circt/Conversion/HandshakeToDC.h @@ -26,6 +26,22 @@ namespace circt { std::unique_ptr createHandshakeToDCPass(); +namespace handshaketodc { +using ConvertedOps = DenseSet; + +// Runs Handshake to DC conversion on the provided op. `patternBuilder` can be +// used to describe additional patterns to run - typically this will be a +// pattern that converts the container operation (e.g. `op`). +// `configureTarget` can be provided to specialize legalization. +LogicalResult runHandshakeToDC( + mlir::Operation *op, + llvm::function_ref + patternBuilder, + llvm::function_ref configureTarget = {}); +} // namespace handshaketodc + namespace handshake { // Converts 't' into a valid HW type. This is strictly used for converting diff --git a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp index 8545bbf7dfd3..e0b52dd72e45 100644 --- a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp +++ b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp @@ -32,11 +32,10 @@ using namespace circt; using namespace handshake; using namespace dc; using namespace hw; +using namespace handshaketodc; namespace { -using ConvertedOps = DenseSet; - struct DCTuple { DCTuple() = default; DCTuple(Value token, Value data) : token(token), data(data) {} @@ -308,9 +307,6 @@ struct UnitRateConversionPattern : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (op->getNumResults() != 1) - return op->emitOpError("expected single result for pattern to apply"); - llvm::SmallVector inputData; llvm::SmallVector inputTokens; for (auto input : operands) { @@ -337,9 +333,12 @@ struct UnitRateConversionPattern : public ConversionPattern { Operation *newOp = rewriter.create(state); joinedOps->insert(newOp); - // Pack the result token with the output data, and replace the use. - rewriter.replaceOp(op, ValueRange{pack(rewriter, join.getResult(), - newOp->getResults().front())}); + // Pack the result token with the output data, and replace the uses. + llvm::SmallVector results; + for (auto result : newOp->getResults()) + results.push_back(pack(rewriter, join, result)); + + rewriter.replaceOp(op, results); return success(); } @@ -391,9 +390,9 @@ class BufferOpConversion : public DCOpConversionPattern { } }; -class ReturnOpConversion : public DCOpConversionPattern { +class ReturnOpConversion : public OpConversionPattern { public: - using DCOpConversionPattern::DCOpConversionPattern; + using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename handshake::ReturnOp::Adaptor; LogicalResult @@ -504,9 +503,9 @@ static hw::ModulePortInfo getModulePortInfoHS(const TypeConverter &tc, return hw::ModulePortInfo{inputs, outputs}; } -class FuncOpConversion : public DCOpConversionPattern { +class FuncOpConversion : public OpConversionPattern { public: - using DCOpConversionPattern::DCOpConversionPattern; + using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename handshake::FuncOp::Adaptor; // Replaces a handshake.func with a hw.module, converting the argument and @@ -547,53 +546,19 @@ class HandshakeToDCPass : public HandshakeToDCBase { public: void runOnOperation() override { mlir::ModuleOp mod = getOperation(); - - // Maintain the set of operations which has been converted either through - // unit rate conversion, or as part of other conversions. - // Rationale: - // This is needed for all of the arith ops that get created as part of the - // handshake ops (e.g. arith.select for handshake.mux). There's a bit of a - // dilemma here seeing as all operations need to be converted/touched in a - // handshake.func - which is done so by UnitRateConversionPattern (when no - // other pattern applies). However, we obviously don't want to run said - // pattern on these newly created ops since they do not have handshake - // semantics. - ConvertedOps convertedOps; - - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - target.addLegalOp(); - - // The various patterns will insert new operations into the module to - // facilitate the conversion - however, these operations must be - // distinguishable from already converted operations (which may be of the - // same type as the newly inserted operations). To do this, we mark all - // operations which have been converted as legal, and all other operations - // as illegal. - target.markUnknownOpDynamicallyLegal( - [&](Operation *op) { return convertedOps.contains(op); }); - - DCTypeConverter typeConverter; - RewritePatternSet patterns(&getContext()); - - // Add handshake conversion patterns. - // Note: merge/control merge are not supported - these are non-deterministic - // operators and we do not care for them. - patterns - .add( - &getContext(), typeConverter, &convertedOps); - - // ALL other single-result operations are converted via the - // UnitRateConversionPattern. - patterns.add(&getContext(), typeConverter, - &convertedOps); - - if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + auto targetModifier = [](mlir::ConversionTarget &target) { + target.addLegalDialect(); + }; + + auto patternBuilder = [&](TypeConverter &typeConverter, + handshaketodc::ConvertedOps &convertedOps, + RewritePatternSet &patterns) { + patterns.add(typeConverter, + mod.getContext()); + }; + + LogicalResult res = runHandshakeToDC(mod, patternBuilder, targetModifier); + if (failed(res)) signalPassFailure(); } }; @@ -602,3 +567,62 @@ class HandshakeToDCPass : public HandshakeToDCBase { std::unique_ptr circt::createHandshakeToDCPass() { return std::make_unique(); } + +LogicalResult circt::handshaketodc::runHandshakeToDC( + mlir::Operation *op, + llvm::function_ref + patternBuilder, + llvm::function_ref configureTarget) { + // Maintain the set of operations which has been converted either through + // unit rate conversion, or as part of other conversions. + // Rationale: + // This is needed for all of the arith ops that get created as part of the + // handshake ops (e.g. arith.select for handshake.mux). There's a bit of a + // dilemma here seeing as all operations need to be converted/touched in a + // handshake.func - which is done so by UnitRateConversionPattern (when no + // other pattern applies). However, we obviously don't want to run said + // pattern on these newly created ops since they do not have handshake + // semantics. + handshaketodc::ConvertedOps convertedOps; + mlir::MLIRContext *ctx = op->getContext(); + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + + // And any user-specified target adjustments + if (configureTarget) + configureTarget(target); + + // The various patterns will insert new operations into the module to + // facilitate the conversion - however, these operations must be + // distinguishable from already converted operations (which may be of the + // same type as the newly inserted operations). To do this, we mark all + // operations which have been converted as legal, and all other operations + // as illegal. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return convertedOps.contains(op); }); + + DCTypeConverter typeConverter; + RewritePatternSet patterns(ctx); + + // Add handshake conversion patterns. + // Note: merge/control merge are not supported - these are non-deterministic + // operators and we do not care for them. + patterns + .add( + ctx, typeConverter, &convertedOps); + + // ALL other single-result operations are converted via the + // UnitRateConversionPattern. + patterns.add(ctx, typeConverter, &convertedOps); + + // Build any user-specified patterns + patternBuilder(typeConverter, convertedOps, patterns); + return applyPartialConversion(op, target, std::move(patterns)); +}