diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir index 9db9667c7be9..3c7e047ec792 100644 --- a/integration_test/circt-synth/comb-lowering-lec.mlir +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -34,3 +34,31 @@ hw.module @mul(in %arg0: i3, in %arg1: i3, in %arg2: i3, out add: i3) { %0 = comb.mul %arg0, %arg1, %arg2 : i3 hw.output %0 : i3 } + +// RUN: circt-lec %t.mlir %s -c1=icmp_eq_ne -c2=icmp_eq_ne --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_EQ_NE +// COMB_ICMP_EQ_NE: c1 == c2 +hw.module @icmp_eq_ne(in %lhs: i3, in %rhs: i3, out out_eq: i1, out out_ne: i1) { + %eq = comb.icmp eq %lhs, %rhs : i3 + %ne = comb.icmp ne %lhs, %rhs : i3 + hw.output %eq, %ne : i1, i1 +} + +// RUN: circt-lec %t.mlir %s -c1=icmp_unsigned_compare -c2=icmp_unsigned_compare --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_UNSIGNED_COMPARE +// COMB_ICMP_UNSIGNED_COMPARE: c1 == c2 +hw.module @icmp_unsigned_compare(in %lhs: i3, in %rhs: i3, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) { + %ugt = comb.icmp ugt %lhs, %rhs : i3 + %uge = comb.icmp uge %lhs, %rhs : i3 + %ult = comb.icmp ult %lhs, %rhs : i3 + %ule = comb.icmp ule %lhs, %rhs : i3 + hw.output %ugt, %uge, %ult, %ule : i1, i1, i1, i1 +} + +// RUN: circt-lec %t.mlir %s -c1=icmp_signed_compare -c2=icmp_signed_compare --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_SIGNED_COMPARE +// COMB_ICMP_SIGNED_COMPARE: c1 == c2 +hw.module @icmp_signed_compare(in %lhs: i3, in %rhs: i3, out out_sgt: i1, out out_sge: i1, out out_slt: i1, out out_sle: i1) { + %sgt = comb.icmp sgt %lhs, %rhs : i3 + %sge = comb.icmp sge %lhs, %rhs : i3 + %slt = comb.icmp slt %lhs, %rhs : i3 + %sle = comb.icmp sle %lhs, %rhs : i3 + hw.output %sgt, %sge, %slt, %sle : i1, i1, i1, i1 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index 336512fcc72f..a54e1a01d3be 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -328,6 +328,119 @@ struct CombMulOpConversion : OpConversionPattern { } }; +struct CombICmpOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + static Value constructUnsignedCompare(ICmpOp op, ArrayRef aBits, + ArrayRef bBits, bool isLess, + bool includeEq, + ConversionPatternRewriter &rewriter) { + // Construct following unsigned comparison expressions. + // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0]) + // a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0]) + // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0]) + // a > b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] > b[n-1:0]) + Value acc = + rewriter.create(op.getLoc(), op.getType(), includeEq); + + for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) { + auto aBitXorBBit = + rewriter.createOrFold(op.getLoc(), aBit, bBit, true); + auto aEqualB = rewriter.createOrFold( + op.getLoc(), aBitXorBBit, true); + auto pred = rewriter.createOrFold( + op.getLoc(), aBit, bBit, isLess, !isLess); + + auto aBitAndBBit = rewriter.createOrFold( + op.getLoc(), ValueRange{aEqualB, acc}, true); + acc = rewriter.createOrFold(op.getLoc(), pred, aBitAndBBit, + true); + } + return acc; + } + + LogicalResult + matchAndRewrite(ICmpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + switch (op.getPredicate()) { + default: + return failure(); + + case ICmpPredicate::eq: + case ICmpPredicate::ceq: { + // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ... + auto xorOp = rewriter.createOrFold(op.getLoc(), lhs, rhs); + auto xorBits = extractBits(rewriter, xorOp); + SmallVector allInverts(xorBits.size(), true); + rewriter.replaceOpWithNewOp(op, xorBits, allInverts); + return success(); + } + + case ICmpPredicate::ne: + case ICmpPredicate::cne: { + // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ... + auto xorOp = rewriter.createOrFold(op.getLoc(), lhs, rhs); + rewriter.replaceOpWithNewOp(op, extractBits(rewriter, xorOp), + true); + return success(); + } + + case ICmpPredicate::uge: + case ICmpPredicate::ugt: + case ICmpPredicate::ule: + case ICmpPredicate::ult: { + bool isLess = op.getPredicate() == ICmpPredicate::ult || + op.getPredicate() == ICmpPredicate::ule; + bool includeEq = op.getPredicate() == ICmpPredicate::uge || + op.getPredicate() == ICmpPredicate::ule; + auto aBits = extractBits(rewriter, lhs); + auto bBits = extractBits(rewriter, rhs); + rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess, + includeEq, rewriter)); + return success(); + } + case ICmpPredicate::slt: + case ICmpPredicate::sle: + case ICmpPredicate::sgt: + case ICmpPredicate::sge: { + if (lhs.getType().getIntOrFloatBitWidth() == 0) + return rewriter.notifyMatchFailure( + op.getLoc(), "i0 signed comparison is unsupported"); + bool isLess = op.getPredicate() == ICmpPredicate::slt || + op.getPredicate() == ICmpPredicate::sle; + bool includeEq = op.getPredicate() == ICmpPredicate::sge || + op.getPredicate() == ICmpPredicate::sle; + + auto aBits = extractBits(rewriter, lhs); + auto bBits = extractBits(rewriter, rhs); + + // Get a sign bit + auto signA = aBits.back(); + auto signB = bBits.back(); + + // Compare magnitudes (all bits except sign) + auto sameSignResult = constructUnsignedCompare( + op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess, + includeEq, rewriter); + + // XOR of signs: true if signs are different + auto signsDiffer = + rewriter.create(op.getLoc(), signA, signB); + + // Result when signs are different + Value diffSignResult = isLess ? signA : signB; + + // Final result: choose based on whether signs differ + rewriter.replaceOpWithNewOp(op, signsDiffer, diffSignResult, + sameSignResult); + return success(); + } + } + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -350,9 +463,10 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { CombMuxOpConversion, // Arithmetic Ops CombAddOpConversion, CombSubOpConversion, CombMulOpConversion, + CombICmpOpConversion, // Variadic ops that must be lowered to binary operations - CombLowerVariadicOp, CombLowerVariadicOp, CombLowerVariadicOp>( - patterns.getContext()); + CombLowerVariadicOp, CombLowerVariadicOp, + CombLowerVariadicOp>(patterns.getContext()); } void ConvertCombToAIGPass::runOnOperation() { diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir index 59d956e953a0..d97012e21f9f 100644 --- a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -1,5 +1,5 @@ -// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s -// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux},cse))" | FileCheck %s +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add},cse))" | FileCheck %s --check-prefix=ALLOW_ADD // CHECK-LABEL: @add @@ -46,3 +46,76 @@ hw.module @mul(in %lhs: i2, in %rhs: i2, out out: i2) { %0 = comb.mul %lhs, %rhs : i2 hw.output %0 : i2 } + +// CHECK-LABEL: @icmp_eq_ne +hw.module @icmp_eq_ne(in %lhs: i2, in %rhs: i2, out out_eq: i1, out out_ne: i1) { + %eq = comb.icmp eq %lhs, %rhs : i2 + %ne = comb.icmp ne %lhs, %rhs : i2 + // CHECK-NEXT: %[[XOR:.+]] = comb.xor %lhs, %rhs + // CHECK-NEXT: %[[XOR_0:.+]] = comb.extract %[[XOR]] from 0 : (i2) -> i1 + // CHECK-NEXT: %[[XOR_1:.+]] = comb.extract %[[XOR]] from 1 : (i2) -> i1 + // CHECK-NEXT: %[[EQ:.+]] = aig.and_inv not %[[XOR_0]], not %[[XOR_1]] + // CHECK-NEXT: %[[NEQ:.+]] = comb.or bin %[[XOR_0]], %[[XOR_1]] + // CHECK-NEXT: hw.output %[[EQ]], %[[NEQ]] + // CHECK-NEXT: } + hw.output %eq, %ne : i1, i1 +} + +// CHECK-LABEL: @icmp_unsigned_compare +hw.module @icmp_unsigned_compare(in %lhs: i2, in %rhs: i2, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) { + %ugt = comb.icmp ugt %lhs, %rhs : i2 + %uge = comb.icmp uge %lhs, %rhs : i2 + %ult = comb.icmp ult %lhs, %rhs : i2 + %ule = comb.icmp ule %lhs, %rhs : i2 + // CHECK-NEXT: %[[LHS_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[LHS_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[RHS_0:.+]] = comb.extract %rhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[RHS_1:.+]] = comb.extract %rhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[LSB_NEQ:.+]] = comb.xor bin %[[LHS_0]], %[[RHS_0]] + // CHECK-NEXT: %[[LSB_GT:.+]] = aig.and_inv %[[LHS_0]], not %[[RHS_0]] + // CHECK-NEXT: %[[MSB_NEQ:.+]] = comb.xor bin %[[LHS_1]], %[[RHS_1]] + // CHECK-NEXT: %[[MSB_EQ:.+]] = aig.and_inv not %[[MSB_NEQ]] + // CHECK-NEXT: %[[MSB_GT:.+]] = aig.and_inv %[[LHS_1]], not %[[RHS_1]] + // CHECK-NEXT: %[[MSB_EQ_AND_LSB_GT:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_GT]] + // CHECK-NEXT: %[[UGT:.+]] = comb.or bin %[[MSB_GT]], %[[MSB_EQ_AND_LSB_GT]] + // CHECK-NEXT: %[[LSB_EQ:.+]] = aig.and_inv not %[[LSB_NEQ]] + // CHECK-NEXT: %[[LSB_UGE:.+]] = comb.or bin %[[LSB_GT]], %[[LSB_EQ]] + // CHECK-NEXT: %[[MSB_EQ_AND_LSB_UGE:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_UGE]] + // CHECK-NEXT: %[[UGE:.+]] = comb.or bin %[[MSB_GT]], %[[MSB_EQ_AND_LSB_UGE]] + // CHECK-NEXT: %[[LSB_LT:.+]] = aig.and_inv not %[[LHS_0]], %[[RHS_0]] + // CHECK-NEXT: %[[MSB_LT:.+]] = aig.and_inv not %[[LHS_1]], %[[RHS_1]] + // CHECK-NEXT: %[[MSB_EQ_AND_LSB_LT:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_LT]] + // CHECK-NEXT: %[[ULT:.+]] = comb.or bin %[[MSB_LT]], %[[MSB_EQ_AND_LSB_LT]] + // CHECK-NEXT: %[[LSB_LE:.+]] = comb.or bin %[[LSB_LT]], %[[LSB_EQ]] + // CHECK-NEXT: %[[MSB_EQ_AND_LSB_LE:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_LE]] + // CHECK-NEXT: %[[ULE:.+]] = comb.or bin %[[MSB_LT]], %[[MSB_EQ_AND_LSB_LE]] + // CHECK-NEXT: hw.output %[[UGT]], %[[UGE]], %[[ULT]], %[[ULE]] + // CHECK-NEXT: } + hw.output %ugt, %uge, %ult, %ule : i1, i1, i1, i1 +} + +// CHECK-LABEL: @icmp_signed_compare +hw.module @icmp_signed_compare(in %lhs: i2, in %rhs: i2, out out_sgt: i1, out out_sge: i1, out out_slt: i1, out out_sle: i1) { + %sgt = comb.icmp sgt %lhs, %rhs : i2 + %sge = comb.icmp sge %lhs, %rhs : i2 + %slt = comb.icmp slt %lhs, %rhs : i2 + %sle = comb.icmp sle %lhs, %rhs : i2 + // CHECK-NEXT: %[[LHS_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[LHS_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[RHS_0:.+]] = comb.extract %rhs from 0 : (i2) -> i1 + // CHECK-NEXT: %[[RHS_1:.+]] = comb.extract %rhs from 1 : (i2) -> i1 + // CHECK-NEXT: %[[LSB_NEQ:.+]] = comb.xor bin %[[LHS_0]], %[[RHS_0]] + // CHECK-NEXT: %[[LSB_GT:.+]] = aig.and_inv %[[LHS_0]], not %[[RHS_0]] + // CHECK-NEXT: %[[SIGN_NEQ:.+]] = comb.xor %[[LHS_1]], %[[RHS_1]] + // CHECK-NEXT: %[[SGT:.+]] = comb.mux %[[SIGN_NEQ]], %[[RHS_1]], %[[LSB_GT]] + // CHECK-NEXT: %[[LSB_EQ:.+]] = aig.and_inv not %[[LSB_NEQ]] + // CHECK-NEXT: %[[LSB_GE:.+]] = comb.or bin %[[LSB_GT]], %[[LSB_EQ]] + // CHECK-NEXT: %[[SGE:.+]] = comb.mux %[[SIGN_NEQ]], %[[RHS_1]], %[[LSB_GE]] + // CHECK-NEXT: %[[LSB_LT:.+]] = aig.and_inv not %[[LHS_0]], %[[RHS_0]] + // CHECK-NEXT: %[[SLT:.+]] = comb.mux %[[SIGN_NEQ]], %[[LHS_1]], %[[LSB_LT]] + // CHECK-NEXT: %[[LSB_LE:.+]] = comb.or bin %[[LSB_LT]], %[[LSB_EQ]] + // CHECK-NEXT: %[[SLE:.+]] = comb.mux %[[SIGN_NEQ]], %[[LHS_1]], %[[LSB_LE]] + // CHECK-NEXT: hw.output %[[SGT]], %[[SGE]], %[[SLT]], %[[SLE]] + // CHECK-NEXT: } + hw.output %sgt, %sge, %slt, %sle : i1, i1, i1, i1 +}