Skip to content

Commit

Permalink
[CombToAIG] Add a lowering for Add/Sub (#7968)
Browse files Browse the repository at this point in the history
This implements a pattern to lower AddOp to a naive ripple-carry adder.
Pattern for sub is also added since we can easily compute subtraction from addition.

This PR also adds a test-only option to `additional-legal-ops` to test
complicated lowering pattern.
  • Loading branch information
uenoku authored Dec 13, 2024
1 parent b7a0513 commit 6634b2b
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/circt/Conversion/CombToAIG.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define CIRCT_CONVERSION_COMBTOAIG_H

#include "circt/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include <string>

namespace circt {

Expand Down
5 changes: 5 additions & 0 deletions include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,11 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> {
"circt::comb::CombDialect",
"circt::aig::AIGDialect",
];

let options = [
ListOption<"additionalLegalOps", "additional-legal-ops", "std::string",
"Specify additional legal ops for testing">,
];
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i3

hw.output %0, %1, %2, %3 : i32, i32, i32, i32
}

// RUN: circt-lec %t.mlir %s -c1=add -c2=add --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ADD
// COMB_ADD: c1 == c2
hw.module @add(in %arg0: i4, in %arg1: i4, in %arg2: i4, out add: i4) {
%0 = comb.add %arg0, %arg1, %arg2 : i4
hw.output %0 : i4
}

// RUN: circt-lec %t.mlir %s -c1=sub -c2=sub --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_SUB
// COMB_SUB: c1 == c2
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}
128 changes: 126 additions & 2 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,39 @@ namespace circt {
using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// Extract individual bits from a value
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
Value val) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
SmallVector<Value> bits;
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return bits;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));

return bits;
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -169,6 +202,87 @@ struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
}
};

struct CombAddOpConversion : OpConversionPattern<AddOp> {
using OpConversionPattern<AddOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto inputs = adaptor.getInputs();
// Lower only when there are two inputs.
// Variadic operands must be lowered in a different pattern.
if (inputs.size() != 2)
return failure();

auto width = op.getType().getIntOrFloatBitWidth();
// Skip a zero width value.
if (width == 0) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
return success();
}

// Implement a naive Ripple-carry full adder.
Value carry;

auto aBits = extractBits(rewriter, inputs[0]);
auto bBits = extractBits(rewriter, inputs[1]);
SmallVector<Value> results;
results.resize(width);
for (int64_t i = 0; i < width; ++i) {
SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
if (carry)
xorOperands.push_back(carry);

// sum[i] = xor(carry[i-1], a[i], b[i])
// NOTE: The result is stored in reverse order.
results[width - i - 1] =
rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);

// If this is the last bit, we are done.
if (i == width - 1) {
break;
}

// carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
Value nextCarry = rewriter.create<comb::AndOp>(
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
if (!carry) {
// This is the first bit, so the carry is the next carry.
carry = nextCarry;
continue;
}

auto aXnorB = rewriter.create<comb::XorOp>(
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
auto andOp = rewriter.create<comb::AndOp>(
op.getLoc(), ValueRange{carry, aXnorB}, true);
carry = rewriter.create<comb::OrOp>(op.getLoc(),
ValueRange{andOp, nextCarry}, true);
}

rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
return success();
}
};

struct CombSubOpConversion : OpConversionPattern<SubOp> {
using OpConversionPattern<SubOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto lhs = op.getLhs();
auto rhs = op.getRhs();
// Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
// sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
// => add(lhs, ~rhs, 1)
auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
/*invert=*/true);
auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
true);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -179,6 +293,8 @@ namespace {
struct ConvertCombToAIGPass
: public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
void runOnOperation() override;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
};
} // namespace

Expand All @@ -187,18 +303,26 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
// Bitwise Logical Ops
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
CombMuxOpConversion,
// Arithmetic Ops
CombAddOpConversion, CombSubOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>>(patterns.getContext());
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp>();
hw::BitcastOp, hw::ConstantOp>();
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
if (!additionalLegalOps.empty())
for (const auto &opName : additionalLegalOps)
target.addLegalOp(OperationName(opName, &getContext()));

RewritePatternSet patterns(&getContext());
populateCombToAIGConversionPatterns(patterns);

Expand Down
29 changes: 29 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD


// CHECK-LABEL: @add
hw.module @add(in %lhs: i2, in %rhs: i2, out out: i2) {
// CHECK: %[[lhs0:.*]] = comb.extract %lhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[lhs1:.*]] = comb.extract %lhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[rhs0:.*]] = comb.extract %rhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[rhs1:.*]] = comb.extract %rhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[sum0:.*]] = comb.xor bin %[[lhs0]], %[[rhs0]] : i1
// CHECK-NEXT: %[[carry0:.*]] = comb.and bin %[[lhs0]], %[[rhs0]] : i1
// CHECK-NEXT: %[[sum1:.*]] = comb.xor bin %[[lhs1]], %[[rhs1]], %[[carry0]] : i1
// CHECK-NEXT: %[[concat:.*]] = comb.concat %[[sum1]], %[[sum0]] : i1, i1
// CHECK-NEXT: hw.output %[[concat]] : i2
%0 = comb.add %lhs, %rhs : i2
hw.output %0 : i2
}

// CHECK-LABEL: @sub
// ALLOW_ADD-LABEL: @sub
// ALLOW_ADD-NEXT: %[[NOT_RHS:.+]] = aig.and_inv not %rhs
// ALLOW_ADD-NEXT: %[[CONST:.+]] = hw.constant 1 : i4
// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %lhs, %[[NOT_RHS]], %[[CONST]]
// ALLOW_ADD-NEXT: hw.output %[[ADD]]
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}

0 comments on commit 6634b2b

Please sign in to comment.