From b900c382dc55dcad3cc3ac619f1c5e9a4d08acbd Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Thu, 12 Dec 2024 23:59:21 -0800 Subject: [PATCH] [CombToAIG] Add mux lowering (#7966) ``` c ? a : b => (replicate(c) & a) | (~replicate(c) & b) ``` Add a LEC test for CombToAIG. --- .../circt-synth/comb-lowering-lec.mlir | 15 ++++++ lib/Conversion/CombToAIG/CombToAIG.cpp | 51 +++++++++++++++++-- test/Conversion/CombToAIG/comb-to-aig.mlir | 15 ++++++ 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 integration_test/circt-synth/comb-lowering-lec.mlir diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir new file mode 100644 index 000000000000..4891299e7d63 --- /dev/null +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -0,0 +1,15 @@ +// REQUIRES: libz3 +// REQUIRES: circt-lec-jit + +// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir +// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL +// COMB_BIT_LOGICAL: c1 == c2 +hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, + in %cond: i1, out out0: i32, out out1: i32, out out2: i32, out out3: i32) { + %0 = comb.or %arg0, %arg1 : i32 + %1 = comb.and %arg0, %arg1 : i32 + %2 = comb.xor %arg0, %arg1 : i32 + %3 = comb.mux %cond, %arg0, %arg1 : i32 + + hw.output %0, %1, %2, %3 : i32, i32, i32, i32 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index 79ba289bca04..94ec4f85461e 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -91,6 +91,48 @@ struct CombXorOpConversion : OpConversionPattern { } }; +// Lower comb::MuxOp to AIG operations. +struct CombMuxOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MuxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b) + + Value cond = op.getCond(); + auto trueVal = op.getTrueValue(); + auto falseVal = op.getFalseValue(); + + if (!op.getType().isInteger()) { + // If the type of the mux is not integer, bitcast the operands first. + auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType())); + trueVal = + rewriter.create(op->getLoc(), widthType, trueVal); + falseVal = + rewriter.create(op->getLoc(), widthType, falseVal); + } + + // Replicate condition if needed + if (!trueVal.getType().isInteger(1)) + cond = rewriter.create(op.getLoc(), trueVal.getType(), + cond); + + // c ? a : b => (replicate(c) & a) | (~replicate(c) & b) + auto lhs = rewriter.create(op.getLoc(), cond, trueVal); + auto rhs = rewriter.create(op.getLoc(), cond, falseVal, + true, false); + + Value result = rewriter.create(op.getLoc(), lhs, rhs); + // Insert the bitcast if the type of the mux is not integer. + if (result.getType() != op.getType()) + result = + rewriter.create(op.getLoc(), op.getType(), result); + rewriter.replaceOp(op, result); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -105,15 +147,18 @@ struct ConvertCombToAIGPass } // namespace static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns.add< + // Bitwise Logical Ops + CombAndOpConversion, CombOrOpConversion, CombXorOpConversion, + CombMuxOpConversion>(patterns.getContext()); } void ConvertCombToAIGPass::runOnOperation() { ConversionTarget target(getContext()); target.addIllegalDialect(); // Keep data movement operations like Extract, Concat and Replicate. - target.addLegalOp(); + target.addLegalOp(); target.addLegalDialect(); RewritePatternSet patterns(&getContext()); diff --git a/test/Conversion/CombToAIG/comb-to-aig.mlir b/test/Conversion/CombToAIG/comb-to-aig.mlir index e4d5b708474d..4628603cb1e7 100644 --- a/test/Conversion/CombToAIG/comb-to-aig.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig.mlir @@ -28,3 +28,18 @@ hw.module @pass(in %arg0: i32, in %arg1: i1, out out: i2) { %3 = comb.and %0, %1, %2 : i2 hw.output %3 : i2 } + +// CHECK-LABEL: @mux +hw.module @mux(in %cond: i1, in %high: !hw.array<2xi4>, in %low: !hw.array<2xi4>, out out: !hw.array<2xi4>) { + // CHECK-NEXT: %[[HIGH:.+]] = hw.bitcast %high : (!hw.array<2xi4>) -> i8 + // CHECK-NEXT: %[[LOW:.+]] = hw.bitcast %low : (!hw.array<2xi4>) -> i8 + // CHECK-NEXT: %[[COND:.+]] = comb.replicate %cond : (i1) -> i8 + // CHECK-NEXT: %[[LHS:.+]] = aig.and_inv %[[COND]], %[[HIGH]] : i8 + // CHECK-NEXT: %[[RHS:.+]] = aig.and_inv not %[[COND]], %[[LOW]] : i8 + // CHECK-NEXT: %[[NAND:.+]] = aig.and_inv not %[[LHS]], not %[[RHS]] : i8 + // CHECK-NEXT: %[[NOT:.+]] = aig.and_inv not %[[NAND]] : i8 + // CHECK-NEXT: %[[RESULT:.+]] = hw.bitcast %[[NOT]] : (i8) -> !hw.array<2xi4> + // CHECK-NEXT: hw.output %[[RESULT]] : !hw.array<2xi4> + %0 = comb.mux %cond, %high, %low : !hw.array<2xi4> + hw.output %0 : !hw.array<2xi4> +}