From 6634b2b683024dde677dc6979418d44b59032396 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Fri, 13 Dec 2024 01:14:37 -0800 Subject: [PATCH] [CombToAIG] Add a lowering for Add/Sub (#7968) This implements a pattern to lower AddOp to a naive ripple-carry adder. Pattern for sub is also added since we can easily compute subtraction from addition. This PR also adds a test-only option to `additional-legal-ops` to test complicated lowering pattern. --- include/circt/Conversion/CombToAIG.h | 2 + include/circt/Conversion/Passes.td | 5 + .../circt-synth/comb-lowering-lec.mlir | 14 ++ lib/Conversion/CombToAIG/CombToAIG.cpp | 128 +++++++++++++++++- .../CombToAIG/comb-to-aig-arith.mlir | 29 ++++ 5 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 test/Conversion/CombToAIG/comb-to-aig-arith.mlir diff --git a/include/circt/Conversion/CombToAIG.h b/include/circt/Conversion/CombToAIG.h index 1c642130fab8..58bd329e99ca 100644 --- a/include/circt/Conversion/CombToAIG.h +++ b/include/circt/Conversion/CombToAIG.h @@ -10,7 +10,9 @@ #define CIRCT_CONVERSION_COMBTOAIG_H #include "circt/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" #include +#include namespace circt { diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 5371fd967ce8..0f3dc53a7b64 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -813,6 +813,11 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> { "circt::comb::CombDialect", "circt::aig::AIGDialect", ]; + + let options = [ + ListOption<"additionalLegalOps", "additional-legal-ops", "std::string", + "Specify additional legal ops for testing">, + ]; } //===----------------------------------------------------------------------===// diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir index 0d51431e578f..87d2d497aaa0 100644 --- a/integration_test/circt-synth/comb-lowering-lec.mlir +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -13,3 +13,17 @@ hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i3 hw.output %0, %1, %2, %3 : i32, i32, i32, i32 } + +// RUN: circt-lec %t.mlir %s -c1=add -c2=add --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ADD +// COMB_ADD: c1 == c2 +hw.module @add(in %arg0: i4, in %arg1: i4, in %arg2: i4, out add: i4) { + %0 = comb.add %arg0, %arg1, %arg2 : i4 + hw.output %0 : i4 +} + +// RUN: circt-lec %t.mlir %s -c1=sub -c2=sub --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_SUB +// COMB_SUB: c1 == c2 +hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { + %0 = comb.sub %lhs, %rhs : i4 + hw.output %0 : i4 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index 19711a068ca2..d9438ed1082f 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -25,6 +25,39 @@ namespace circt { using namespace circt; using namespace comb; +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +// Extract individual bits from a value +static SmallVector extractBits(ConversionPatternRewriter &rewriter, + Value val) { + assert(val.getType().isInteger() && "expected integer"); + auto width = val.getType().getIntOrFloatBitWidth(); + SmallVector bits; + bits.reserve(width); + + // Check if we can reuse concat operands + if (auto concat = val.getDefiningOp()) { + if (concat.getNumOperands() == width && + llvm::all_of(concat.getOperandTypes(), [](Type type) { + return type.getIntOrFloatBitWidth() == 1; + })) { + // Reverse the operands to match the bit order + bits.append(std::make_reverse_iterator(concat.getOperands().end()), + std::make_reverse_iterator(concat.getOperands().begin())); + return bits; + } + } + + // Extract individual bits + for (int64_t i = 0; i < width; ++i) + bits.push_back( + rewriter.createOrFold(val.getLoc(), val, i, 1)); + + return bits; +} + //===----------------------------------------------------------------------===// // Conversion patterns //===----------------------------------------------------------------------===// @@ -169,6 +202,87 @@ struct CombMuxOpConversion : OpConversionPattern { } }; +struct CombAddOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputs = adaptor.getInputs(); + // Lower only when there are two inputs. + // Variadic operands must be lowered in a different pattern. + if (inputs.size() != 2) + return failure(); + + auto width = op.getType().getIntOrFloatBitWidth(); + // Skip a zero width value. + if (width == 0) { + rewriter.replaceOpWithNewOp(op, op.getType(), 0); + return success(); + } + + // Implement a naive Ripple-carry full adder. + Value carry; + + auto aBits = extractBits(rewriter, inputs[0]); + auto bBits = extractBits(rewriter, inputs[1]); + SmallVector results; + results.resize(width); + for (int64_t i = 0; i < width; ++i) { + SmallVector xorOperands = {aBits[i], bBits[i]}; + if (carry) + xorOperands.push_back(carry); + + // sum[i] = xor(carry[i-1], a[i], b[i]) + // NOTE: The result is stored in reverse order. + results[width - i - 1] = + rewriter.create(op.getLoc(), xorOperands, true); + + // If this is the last bit, we are done. + if (i == width - 1) { + break; + } + + // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i]) + Value nextCarry = rewriter.create( + op.getLoc(), ValueRange{aBits[i], bBits[i]}, true); + if (!carry) { + // This is the first bit, so the carry is the next carry. + carry = nextCarry; + continue; + } + + auto aXnorB = rewriter.create( + op.getLoc(), ValueRange{aBits[i], bBits[i]}, true); + auto andOp = rewriter.create( + op.getLoc(), ValueRange{carry, aXnorB}, true); + carry = rewriter.create(op.getLoc(), + ValueRange{andOp, nextCarry}, true); + } + + rewriter.replaceOpWithNewOp(op, results); + return success(); + } +}; + +struct CombSubOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to: + // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1)) + // => add(lhs, ~rhs, 1) + auto notRhs = rewriter.create(op.getLoc(), rhs, + /*invert=*/true); + auto one = rewriter.create(op.getLoc(), op.getType(), 1); + rewriter.replaceOpWithNewOp(op, ValueRange{lhs, notRhs, one}, + true); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -179,6 +293,8 @@ namespace { struct ConvertCombToAIGPass : public impl::ConvertCombToAIGBase { void runOnOperation() override; + using ConvertCombToAIGBase::ConvertCombToAIGBase; + using ConvertCombToAIGBase::additionalLegalOps; }; } // namespace @@ -187,8 +303,11 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { // Bitwise Logical Ops CombAndOpConversion, CombOrOpConversion, CombXorOpConversion, CombMuxOpConversion, + // Arithmetic Ops + CombAddOpConversion, CombSubOpConversion, // Variadic ops that must be lowered to binary operations - CombLowerVariadicOp>(patterns.getContext()); + CombLowerVariadicOp, CombLowerVariadicOp>( + patterns.getContext()); } void ConvertCombToAIGPass::runOnOperation() { @@ -196,9 +315,14 @@ void ConvertCombToAIGPass::runOnOperation() { target.addIllegalDialect(); // Keep data movement operations like Extract, Concat and Replicate. target.addLegalOp(); + hw::BitcastOp, hw::ConstantOp>(); target.addLegalDialect(); + // This is a test only option to add logical ops. + if (!additionalLegalOps.empty()) + for (const auto &opName : additionalLegalOps) + target.addLegalOp(OperationName(opName, &getContext())); + RewritePatternSet patterns(&getContext()); populateCombToAIGConversionPatterns(patterns); diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir new file mode 100644 index 000000000000..1782dc3f216e --- /dev/null +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -0,0 +1,29 @@ +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD + + +// CHECK-LABEL: @add +hw.module @add(in %lhs: i2, in %rhs: i2, out out: i2) { + // CHECK: %[[lhs0:.*]] = comb.extract %lhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[lhs1:.*]] = comb.extract %lhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[rhs0:.*]] = comb.extract %rhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[rhs1:.*]] = comb.extract %rhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[sum0:.*]] = comb.xor bin %[[lhs0]], %[[rhs0]] : i1 + // CHECK-NEXT: %[[carry0:.*]] = comb.and bin %[[lhs0]], %[[rhs0]] : i1 + // CHECK-NEXT: %[[sum1:.*]] = comb.xor bin %[[lhs1]], %[[rhs1]], %[[carry0]] : i1 + // CHECK-NEXT: %[[concat:.*]] = comb.concat %[[sum1]], %[[sum0]] : i1, i1 + // CHECK-NEXT: hw.output %[[concat]] : i2 + %0 = comb.add %lhs, %rhs : i2 + hw.output %0 : i2 +} + +// CHECK-LABEL: @sub +// ALLOW_ADD-LABEL: @sub +// ALLOW_ADD-NEXT: %[[NOT_RHS:.+]] = aig.and_inv not %rhs +// ALLOW_ADD-NEXT: %[[CONST:.+]] = hw.constant 1 : i4 +// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %lhs, %[[NOT_RHS]], %[[CONST]] +// ALLOW_ADD-NEXT: hw.output %[[ADD]] +hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { + %0 = comb.sub %lhs, %rhs : i4 + hw.output %0 : i4 +}