Skip to content

Commit

Permalink
[CombToAIG] Add mux lowering (#7966)
Browse files Browse the repository at this point in the history
```
c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
```

Add a LEC test for CombToAIG.
  • Loading branch information
uenoku authored Dec 13, 2024
1 parent 7a98d67 commit b900c38
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
15 changes: 15 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL
// COMB_BIT_LOGICAL: c1 == c2
hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32,
in %cond: i1, out out0: i32, out out1: i32, out out2: i32, out out3: i32) {
%0 = comb.or %arg0, %arg1 : i32
%1 = comb.and %arg0, %arg1 : i32
%2 = comb.xor %arg0, %arg1 : i32
%3 = comb.mux %cond, %arg0, %arg1 : i32

hw.output %0, %1, %2, %3 : i32, i32, i32, i32
}
51 changes: 48 additions & 3 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,48 @@ struct CombXorOpConversion : OpConversionPattern<XorOp> {
}
};

// Lower comb::MuxOp to AIG operations.
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
using OpConversionPattern<MuxOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(MuxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)

Value cond = op.getCond();
auto trueVal = op.getTrueValue();
auto falseVal = op.getFalseValue();

if (!op.getType().isInteger()) {
// If the type of the mux is not integer, bitcast the operands first.
auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
trueVal =
rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, trueVal);
falseVal =
rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, falseVal);
}

// Replicate condition if needed
if (!trueVal.getType().isInteger(1))
cond = rewriter.create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
cond);

// c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
true, false);

Value result = rewriter.create<comb::OrOp>(op.getLoc(), lhs, rhs);
// Insert the bitcast if the type of the mux is not integer.
if (result.getType() != op.getType())
result =
rewriter.create<hw::BitcastOp>(op.getLoc(), op.getType(), result);
rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -105,15 +147,18 @@ struct ConvertCombToAIGPass
} // namespace

static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
patterns.add<CombAndOpConversion, CombOrOpConversion, CombXorOpConversion>(
patterns.getContext());
patterns.add<
// Bitwise Logical Ops
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
CombMuxOpConversion>(patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp>();
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp>();
target.addLegalDialect<aig::AIGDialect>();

RewritePatternSet patterns(&getContext());
Expand Down
15 changes: 15 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,18 @@ hw.module @pass(in %arg0: i32, in %arg1: i1, out out: i2) {
%3 = comb.and %0, %1, %2 : i2
hw.output %3 : i2
}

// CHECK-LABEL: @mux
hw.module @mux(in %cond: i1, in %high: !hw.array<2xi4>, in %low: !hw.array<2xi4>, out out: !hw.array<2xi4>) {
// CHECK-NEXT: %[[HIGH:.+]] = hw.bitcast %high : (!hw.array<2xi4>) -> i8
// CHECK-NEXT: %[[LOW:.+]] = hw.bitcast %low : (!hw.array<2xi4>) -> i8
// CHECK-NEXT: %[[COND:.+]] = comb.replicate %cond : (i1) -> i8
// CHECK-NEXT: %[[LHS:.+]] = aig.and_inv %[[COND]], %[[HIGH]] : i8
// CHECK-NEXT: %[[RHS:.+]] = aig.and_inv not %[[COND]], %[[LOW]] : i8
// CHECK-NEXT: %[[NAND:.+]] = aig.and_inv not %[[LHS]], not %[[RHS]] : i8
// CHECK-NEXT: %[[NOT:.+]] = aig.and_inv not %[[NAND]] : i8
// CHECK-NEXT: %[[RESULT:.+]] = hw.bitcast %[[NOT]] : (i8) -> !hw.array<2xi4>
// CHECK-NEXT: hw.output %[[RESULT]] : !hw.array<2xi4>
%0 = comb.mux %cond, %high, %low : !hw.array<2xi4>
hw.output %0 : !hw.array<2xi4>
}

0 comments on commit b900c38

Please sign in to comment.