From 5e6bdceb9fcca40b6af6681c3c67974becf3888a Mon Sep 17 00:00:00 2001 From: TaoBi22 Date: Mon, 18 Sep 2023 17:57:18 +0100 Subject: [PATCH] Fix state propagation and add seq test --- include/circt/LogicalEquivalence/Circuit.h | 22 ++++++++-------- integration_test/circt-mc/seq.mlir | 15 +++++++++++ lib/LogicalEquivalence/Circuit.cpp | 29 +++++++++++----------- 3 files changed, 40 insertions(+), 26 deletions(-) create mode 100644 integration_test/circt-mc/seq.mlir diff --git a/include/circt/LogicalEquivalence/Circuit.h b/include/circt/LogicalEquivalence/Circuit.h index 5cb7aa6c337c..7a792b6a64a8 100644 --- a/include/circt/LogicalEquivalence/Circuit.h +++ b/include/circt/LogicalEquivalence/Circuit.h @@ -113,8 +113,7 @@ class Solver::Circuit { /// over a range of operands. void variadicOperation( mlir::Value result, mlir::OperandRange operands, - llvm::function_ref - operation); + std::function operation); /// Returns the expression allocated for the input value in the logical /// backend if one has been allocated - otherwise allocates and returns a new /// expression @@ -135,8 +134,8 @@ class Solver::Circuit { /// transform void applyCombVariadicOperation( mlir::Value, - std::pair>); + std::pair>); /// Push solver constraints assigning registers and inputs to their current /// state @@ -188,16 +187,15 @@ class Solver::Circuit { /// A type to represent the different representations of combinational /// transforms using TransformVariant = std::variant< - std::pair>, + std::pair>, std::pair, - llvm::function_ref>, - std::pair< - std::tuple, - llvm::function_ref>, + std::function>, + std::pair, + std::function>, std::pair, - llvm::function_ref>>; + std::function>>; /// A map from wire values to their corresponding transformations. llvm::DenseMap combTransformTable; diff --git a/integration_test/circt-mc/seq.mlir b/integration_test/circt-mc/seq.mlir new file mode 100644 index 000000000000..dbbef8383d4a --- /dev/null +++ b/integration_test/circt-mc/seq.mlir @@ -0,0 +1,15 @@ +// These tests will be only enabled if circt-mc is built. +// REQUIRES: circt-mc + +// RUN: circt-mc %s -b 10 --module ClkProp | FileCheck %s --check-prefix=CLKPROP +// CLKPROP: Success! + +hw.module @ClkProp(%i0: i1, %clk: i1) { + %reg = seq.compreg %i0, %clk : i1 + // Condition (equivalent to %clk -> %reg == %i0) + %c-1_i1 = hw.constant -1 : i1 + %nclk = comb.xor bin %clk, %c-1_i1 : i1 + %eq = comb.icmp bin eq %i0, %reg : i1 + %imp = comb.or bin %nclk, %eq : i1 + verif.assert %imp : i1 +} diff --git a/lib/LogicalEquivalence/Circuit.cpp b/lib/LogicalEquivalence/Circuit.cpp index 8520a480f1e0..3287b9bca907 100644 --- a/lib/LogicalEquivalence/Circuit.cpp +++ b/lib/LogicalEquivalence/Circuit.cpp @@ -439,8 +439,7 @@ void Solver::Circuit::performXor(Value result, OperandRange operands) { /// over a range of operands. void Solver::Circuit::variadicOperation( Value result, OperandRange operands, - llvm::function_ref - operation) { + std::function operation) { // Allocate operands if unallocated LLVM_DEBUG(lec::dbgs() << "variadic operation\n"); lec::Scope indent; @@ -466,6 +465,7 @@ void Solver::Circuit::variadicOperation( ++it; } constrainResult(result, varOp); + combTransformTable.insert(std::pair(result, std::pair(operands, operation))); } /// Allocates an IR value in the logical backend and returns its representing @@ -566,6 +566,7 @@ void Solver::Circuit::constrainResult(Value &result, z3::expr &expr) { LLVM_DEBUG(lec::dbgs() << constraint.to_string() << "\n"); } solver.solver.add(constraint); + wires.push_back(result); } /// Convert from bitvector to bool sort. @@ -732,26 +733,26 @@ void Solver::Circuit::applyCombUpdates() { auto wireTransform = wireTransformPair->second; if (auto *transform = std::get_if>>( + std::function>>( &wireTransform)) { applyCombVariadicOperation(wire, *transform); } else if (auto *transform = std::get_if< std::pair, - llvm::function_ref>>( + std::function>>( &wireTransform)) { mlir::Value operand = std::get<0>(transform->first); - llvm::function_ref transformFunc = + std::function transformFunc = transform->second; z3::expr operandExpr = stateTable.find(operand)->second; stateTable.find(wire)->second = transformFunc(operandExpr); } else if (auto *transform = std::get_if< std::pair, - llvm::function_ref>>( + std::function>>( &wireTransform)) { mlir::Value firstOperand = std::get<0>(transform->first); mlir::Value secondOperand = std::get<1>(transform->first); - llvm::function_ref + std::function transformFunc = transform->second; z3::expr firstOperandExpr = stateTable.find(firstOperand)->second; z3::expr secondOperandExpr = stateTable.find(secondOperand)->second; @@ -759,14 +760,14 @@ void Solver::Circuit::applyCombUpdates() { transformFunc(firstOperandExpr, secondOperandExpr); } else if (auto *transform = std::get_if, - llvm::function_ref>>( + std::function>>( &wireTransform)) { mlir::Value firstOperand = std::get<0>(transform->first); mlir::Value secondOperand = std::get<1>(transform->first); mlir::Value thirdOperand = std::get<2>(transform->first); - llvm::function_ref + std::function transformFunc = transform->second; z3::expr firstOperandExpr = stateTable.find(firstOperand)->second; z3::expr secondOperandExpr = stateTable.find(secondOperand)->second; @@ -782,12 +783,12 @@ void Solver::Circuit::applyCombUpdates() { void Solver::Circuit::applyCombVariadicOperation( mlir::Value result, std::pair> + std::function> operationPair) { LLVM_DEBUG(lec::dbgs() << "comb variadic operation\n"); lec::Scope indent; mlir::OperandRange operands = operationPair.first; - llvm::function_ref operation = + std::function operation = operationPair.second; // Vacuous base case. auto it = operands.begin();