diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir index 87d2d497aaa0..9db9667c7be9 100644 --- a/integration_test/circt-synth/comb-lowering-lec.mlir +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -27,3 +27,10 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { %0 = comb.sub %lhs, %rhs : i4 hw.output %0 : i4 } + +// RUN: circt-lec %t.mlir %s -c1=mul -c2=mul --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL +// COMB_MUL: c1 == c2 +hw.module @mul(in %arg0: i3, in %arg1: i3, in %arg2: i3, out add: i3) { + %0 = comb.mul %arg0, %arg1, %arg2 : i3 + hw.output %0 : i3 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index d9438ed1082f..336512fcc72f 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -283,6 +283,51 @@ struct CombSubOpConversion : OpConversionPattern { } }; +struct CombMulOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + LogicalResult + matchAndRewrite(MulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getInputs().size() != 2) + return failure(); + + // FIXME: Currently it's lowered to a really naive implementation that + // chains add operations. + + // a_{n}a_{n-1}...a_0 * b + // = sum_{i=0}^{n} a_i * 2^i * b + // = sum_{i=0}^{n} (a_i ? b : 0) << i + int64_t width = op.getType().getIntOrFloatBitWidth(); + auto aBits = extractBits(rewriter, adaptor.getInputs()[0]); + SmallVector results; + auto rhs = op.getInputs()[1]; + auto zero = rewriter.create(op.getLoc(), + llvm::APInt::getZero(width)); + for (int64_t i = 0; i < width; ++i) { + auto aBit = aBits[i]; + auto andBit = + rewriter.createOrFold(op.getLoc(), aBit, rhs, zero); + auto upperBits = rewriter.createOrFold( + op.getLoc(), andBit, 0, width - i); + if (i == 0) { + results.push_back(upperBits); + continue; + } + + auto lowerBits = + rewriter.create(op.getLoc(), APInt::getZero(i)); + + auto shifted = rewriter.createOrFold( + op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits}); + results.push_back(shifted); + } + + rewriter.replaceOpWithNewOp(op, results, true); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -304,9 +349,9 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { CombAndOpConversion, CombOrOpConversion, CombXorOpConversion, CombMuxOpConversion, // Arithmetic Ops - CombAddOpConversion, CombSubOpConversion, + CombAddOpConversion, CombSubOpConversion, CombMulOpConversion, // Variadic ops that must be lowered to binary operations - CombLowerVariadicOp, CombLowerVariadicOp>( + CombLowerVariadicOp, CombLowerVariadicOp, CombLowerVariadicOp>( patterns.getContext()); } diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir index 1782dc3f216e..59d956e953a0 100644 --- a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -27,3 +27,22 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { %0 = comb.sub %lhs, %rhs : i4 hw.output %0 : i4 } + + +// CHECK-LABEL: @mul +// ALLOW_ADD-LABEL: @mul +// ALLOW_ADD-NEXT: %[[EXT_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1 +// ALLOW_ADD-NEXT: %[[EXT_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1 +// ALLOW_ADD-NEXT: %c0_i2 = hw.constant 0 : i2 +// ALLOW_ADD-NEXT: %[[MUX_0:.+]] = comb.mux %0, %rhs, %c0_i2 : i2 +// ALLOW_ADD-NEXT: %[[MUX_1:.+]] = comb.mux %1, %rhs, %c0_i2 : i2 +// ALLOW_ADD-NEXT: %[[EXT_MUX_1:.+]] = comb.extract %3 from 0 : (i2) -> i1 +// ALLOW_ADD-NEXT: %false = hw.constant false +// ALLOW_ADD-NEXT: %[[SHIFT:.+]] = comb.concat %4, %false : i1, i1 +// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %[[MUX_0]], %[[SHIFT]] : i2 +// ALLOW_ADD-NEXT: hw.output %[[ADD]] : i2 +// ALLOW_ADD-NEXT: } +hw.module @mul(in %lhs: i2, in %rhs: i2, out out: i2) { + %0 = comb.mul %lhs, %rhs : i2 + hw.output %0 : i2 +}