Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CombToAIG] Add a lowering for Add/Sub #7968

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/circt/Conversion/CombToAIG.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define CIRCT_CONVERSION_COMBTOAIG_H

#include "circt/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include <string>

namespace circt {

Expand Down
5 changes: 5 additions & 0 deletions include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,11 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> {
"circt::comb::CombDialect",
"circt::aig::AIGDialect",
];

let options = [
ListOption<"additionalLegalOps", "additional-legal-ops", "std::string",
"Specify additional legal ops for testing">,
];
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i3

hw.output %0, %1, %2, %3 : i32, i32, i32, i32
}

// RUN: circt-lec %t.mlir %s -c1=add -c2=add --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ADD
// COMB_ADD: c1 == c2
hw.module @add(in %arg0: i4, in %arg1: i4, in %arg2: i4, out add: i4) {
%0 = comb.add %arg0, %arg1, %arg2 : i4
hw.output %0 : i4
}

// RUN: circt-lec %t.mlir %s -c1=sub -c2=sub --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_SUB
// COMB_SUB: c1 == c2
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}
128 changes: 126 additions & 2 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,39 @@ namespace circt {
using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// Extract individual bits from a value
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
Value val) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
SmallVector<Value> bits;
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return bits;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));

return bits;
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -169,6 +202,87 @@ struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
}
};

struct CombAddOpConversion : OpConversionPattern<AddOp> {
using OpConversionPattern<AddOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto inputs = adaptor.getInputs();
// Lower only when there are two inputs.
// Variadic operands must be lowered in a different pattern.
if (inputs.size() != 2)
return failure();

auto width = op.getType().getIntOrFloatBitWidth();
// Skip a zero width value.
if (width == 0) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
return success();
}

// Implement a naive Ripple-carry full adder.
Value carry;

auto aBits = extractBits(rewriter, inputs[0]);
auto bBits = extractBits(rewriter, inputs[1]);
SmallVector<Value> results;
results.resize(width);
for (int64_t i = 0; i < width; ++i) {
SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
if (carry)
xorOperands.push_back(carry);

// sum[i] = xor(carry[i-1], a[i], b[i])
// NOTE: The result is stored in reverse order.
results[width - i - 1] =
rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);

// If this is the last bit, we are done.
if (i == width - 1) {
break;
}

// carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
Value nextCarry = rewriter.create<comb::AndOp>(
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
if (!carry) {
// This is the first bit, so the carry is the next carry.
carry = nextCarry;
continue;
}

auto aXnorB = rewriter.create<comb::XorOp>(
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
auto andOp = rewriter.create<comb::AndOp>(
op.getLoc(), ValueRange{carry, aXnorB}, true);
carry = rewriter.create<comb::OrOp>(op.getLoc(),
ValueRange{andOp, nextCarry}, true);
}

rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
return success();
}
};

struct CombSubOpConversion : OpConversionPattern<SubOp> {
using OpConversionPattern<SubOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto lhs = op.getLhs();
auto rhs = op.getRhs();
// Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
// sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
// => add(lhs, ~rhs, 1)
auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
/*invert=*/true);
auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
true);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -179,6 +293,8 @@ namespace {
struct ConvertCombToAIGPass
: public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
void runOnOperation() override;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
};
} // namespace

Expand All @@ -187,18 +303,26 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
// Bitwise Logical Ops
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
CombMuxOpConversion,
// Arithmetic Ops
CombAddOpConversion, CombSubOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>>(patterns.getContext());
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp>();
hw::BitcastOp, hw::ConstantOp>();
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
if (!additionalLegalOps.empty())
for (const auto &opName : additionalLegalOps)
target.addLegalOp(OperationName(opName, &getContext()));

RewritePatternSet patterns(&getContext());
populateCombToAIGConversionPatterns(patterns);

Expand Down
29 changes: 29 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD


// CHECK-LABEL: @add
hw.module @add(in %lhs: i2, in %rhs: i2, out out: i2) {
// CHECK: %[[lhs0:.*]] = comb.extract %lhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[lhs1:.*]] = comb.extract %lhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[rhs0:.*]] = comb.extract %rhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[rhs1:.*]] = comb.extract %rhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[sum0:.*]] = comb.xor bin %[[lhs0]], %[[rhs0]] : i1
// CHECK-NEXT: %[[carry0:.*]] = comb.and bin %[[lhs0]], %[[rhs0]] : i1
// CHECK-NEXT: %[[sum1:.*]] = comb.xor bin %[[lhs1]], %[[rhs1]], %[[carry0]] : i1
// CHECK-NEXT: %[[concat:.*]] = comb.concat %[[sum1]], %[[sum0]] : i1, i1
// CHECK-NEXT: hw.output %[[concat]] : i2
%0 = comb.add %lhs, %rhs : i2
hw.output %0 : i2
}

// CHECK-LABEL: @sub
// ALLOW_ADD-LABEL: @sub
// ALLOW_ADD-NEXT: %[[NOT_RHS:.+]] = aig.and_inv not %rhs
// ALLOW_ADD-NEXT: %[[CONST:.+]] = hw.constant 1 : i4
// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %lhs, %[[NOT_RHS]], %[[CONST]]
// ALLOW_ADD-NEXT: hw.output %[[ADD]]
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}
Loading