Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CombToAIG] Add a pattern for mul #8015

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}

// RUN: circt-lec %t.mlir %s -c1=mul -c2=mul --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL
// COMB_MUL: c1 == c2
hw.module @mul(in %arg0: i3, in %arg1: i3, in %arg2: i3, out add: i3) {
%0 = comb.mul %arg0, %arg1, %arg2 : i3
hw.output %0 : i3
}
51 changes: 48 additions & 3 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,51 @@ struct CombSubOpConversion : OpConversionPattern<SubOp> {
}
};

struct CombMulOpConversion : OpConversionPattern<MulOp> {
using OpConversionPattern<MulOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
LogicalResult
matchAndRewrite(MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getInputs().size() != 2)
return failure();

// FIXME: Currently it's lowered to a really naive implementation that
// chains add operations.

// a_{n}a_{n-1}...a_0 * b
// = sum_{i=0}^{n} a_i * 2^i * b
// = sum_{i=0}^{n} (a_i ? b : 0) << i
int64_t width = op.getType().getIntOrFloatBitWidth();
auto aBits = extractBits(rewriter, adaptor.getInputs()[0]);
SmallVector<Value> results;
auto rhs = op.getInputs()[1];
auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
llvm::APInt::getZero(width));
for (int64_t i = 0; i < width; ++i) {
auto aBit = aBits[i];
auto andBit =
rewriter.createOrFold<comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
auto upperBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), andBit, 0, width - i);
if (i == 0) {
results.push_back(upperBits);
continue;
}

auto lowerBits =
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(i));

auto shifted = rewriter.createOrFold<comb::ConcatOp>(
op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
results.push_back(shifted);
}

rewriter.replaceOpWithNewOp<comb::AddOp>(op, results, true);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -304,10 +349,10 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
CombMuxOpConversion,
// Arithmetic Ops
CombAddOpConversion, CombSubOpConversion,
CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
patterns.getContext());
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
CombLowerVariadicOp<MulOp>>(patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,22 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}


// CHECK-LABEL: @mul
// ALLOW_ADD-LABEL: @mul
// ALLOW_ADD-NEXT: %[[EXT_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1
// ALLOW_ADD-NEXT: %[[EXT_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1
// ALLOW_ADD-NEXT: %c0_i2 = hw.constant 0 : i2
// ALLOW_ADD-NEXT: %[[MUX_0:.+]] = comb.mux %0, %rhs, %c0_i2 : i2
// ALLOW_ADD-NEXT: %[[MUX_1:.+]] = comb.mux %1, %rhs, %c0_i2 : i2
// ALLOW_ADD-NEXT: %[[EXT_MUX_1:.+]] = comb.extract %3 from 0 : (i2) -> i1
// ALLOW_ADD-NEXT: %false = hw.constant false
// ALLOW_ADD-NEXT: %[[SHIFT:.+]] = comb.concat %4, %false : i1, i1
// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %[[MUX_0]], %[[SHIFT]] : i2
// ALLOW_ADD-NEXT: hw.output %[[ADD]] : i2
// ALLOW_ADD-NEXT: }
hw.module @mul(in %lhs: i2, in %rhs: i2, out out: i2) {
%0 = comb.mul %lhs, %rhs : i2
hw.output %0 : i2
}
Loading