Skip to content

Commit

Permalink
Fix state propagation and add seq test
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoBi22 committed Sep 18, 2023
1 parent d273306 commit 5e6bdce
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 26 deletions.
22 changes: 10 additions & 12 deletions include/circt/LogicalEquivalence/Circuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ class Solver::Circuit {
/// over a range of operands.
void variadicOperation(
mlir::Value result, mlir::OperandRange operands,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>
operation);
std::function<z3::expr(const z3::expr &, const z3::expr &)> operation);
/// Returns the expression allocated for the input value in the logical
/// backend if one has been allocated - otherwise allocates and returns a new
/// expression
Expand All @@ -135,8 +134,8 @@ class Solver::Circuit {
/// transform
void applyCombVariadicOperation(
mlir::Value,
std::pair<mlir::OperandRange, llvm::function_ref<z3::expr(
const z3::expr &, const z3::expr &)>>);
std::pair<mlir::OperandRange,
std::function<z3::expr(const z3::expr &, const z3::expr &)>>);

/// Push solver constraints assigning registers and inputs to their current
/// state
Expand Down Expand Up @@ -188,16 +187,15 @@ class Solver::Circuit {
/// A type to represent the different representations of combinational
/// transforms
using TransformVariant = std::variant<
std::pair<mlir::OperandRange, llvm::function_ref<z3::expr(
const z3::expr &, const z3::expr &)>>,
std::pair<mlir::OperandRange,
std::function<z3::expr(const z3::expr &, const z3::expr &)>>,
std::pair<std::tuple<mlir::Value>,
llvm::function_ref<z3::expr(const z3::expr &)>>,
std::pair<
std::tuple<mlir::Value, mlir::Value>,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>>,
std::function<z3::expr(const z3::expr &)>>,
std::pair<std::tuple<mlir::Value, mlir::Value>,
std::function<z3::expr(const z3::expr &, const z3::expr &)>>,
std::pair<std::tuple<mlir::Value, mlir::Value, mlir::Value>,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &,
const z3::expr &)>>>;
std::function<z3::expr(const z3::expr &, const z3::expr &,
const z3::expr &)>>>;
/// A map from wire values to their corresponding transformations.
llvm::DenseMap<mlir::Value, TransformVariant> combTransformTable;

Expand Down
15 changes: 15 additions & 0 deletions integration_test/circt-mc/seq.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// These tests will be only enabled if circt-mc is built.
// REQUIRES: circt-mc

// RUN: circt-mc %s -b 10 --module ClkProp | FileCheck %s --check-prefix=CLKPROP
// CLKPROP: Success!

hw.module @ClkProp(%i0: i1, %clk: i1) {
%reg = seq.compreg %i0, %clk : i1
// Condition (equivalent to %clk -> %reg == %i0)
%c-1_i1 = hw.constant -1 : i1
%nclk = comb.xor bin %clk, %c-1_i1 : i1
%eq = comb.icmp bin eq %i0, %reg : i1
%imp = comb.or bin %nclk, %eq : i1
verif.assert %imp : i1
}
29 changes: 15 additions & 14 deletions lib/LogicalEquivalence/Circuit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,7 @@ void Solver::Circuit::performXor(Value result, OperandRange operands) {
/// over a range of operands.
void Solver::Circuit::variadicOperation(
Value result, OperandRange operands,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>
operation) {
std::function<z3::expr(const z3::expr &, const z3::expr &)> operation) {
// Allocate operands if unallocated
LLVM_DEBUG(lec::dbgs() << "variadic operation\n");
lec::Scope indent;
Expand All @@ -466,6 +465,7 @@ void Solver::Circuit::variadicOperation(
++it;
}
constrainResult(result, varOp);
combTransformTable.insert(std::pair(result, std::pair(operands, operation)));
}

/// Allocates an IR value in the logical backend and returns its representing
Expand Down Expand Up @@ -566,6 +566,7 @@ void Solver::Circuit::constrainResult(Value &result, z3::expr &expr) {
LLVM_DEBUG(lec::dbgs() << constraint.to_string() << "\n");
}
solver.solver.add(constraint);
wires.push_back(result);
}

/// Convert from bitvector to bool sort.
Expand Down Expand Up @@ -732,41 +733,41 @@ void Solver::Circuit::applyCombUpdates() {
auto wireTransform = wireTransformPair->second;
if (auto *transform = std::get_if<std::pair<
mlir::OperandRange,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>>>(
std::function<z3::expr(const z3::expr &, const z3::expr &)>>>(
&wireTransform)) {
applyCombVariadicOperation(wire, *transform);
} else if (auto *transform = std::get_if<
std::pair<std::tuple<mlir::Value>,
llvm::function_ref<z3::expr(const z3::expr &)>>>(
std::function<z3::expr(const z3::expr &)>>>(
&wireTransform)) {
mlir::Value operand = std::get<0>(transform->first);
llvm::function_ref<z3::expr(const z3::expr &)> transformFunc =
std::function<z3::expr(const z3::expr &)> transformFunc =
transform->second;
z3::expr operandExpr = stateTable.find(operand)->second;
stateTable.find(wire)->second = transformFunc(operandExpr);
} else if (auto *transform = std::get_if<
std::pair<std::tuple<mlir::Value, mlir::Value>,
llvm::function_ref<z3::expr(const z3::expr &,
const z3::expr &)>>>(
std::function<z3::expr(const z3::expr &,
const z3::expr &)>>>(
&wireTransform)) {
mlir::Value firstOperand = std::get<0>(transform->first);
mlir::Value secondOperand = std::get<1>(transform->first);
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>
std::function<z3::expr(const z3::expr &, const z3::expr &)>
transformFunc = transform->second;
z3::expr firstOperandExpr = stateTable.find(firstOperand)->second;
z3::expr secondOperandExpr = stateTable.find(secondOperand)->second;
stateTable.find(wire)->second =
transformFunc(firstOperandExpr, secondOperandExpr);
} else if (auto *transform = std::get_if<std::pair<
std::tuple<mlir::Value, mlir::Value, mlir::Value>,
llvm::function_ref<z3::expr(
const z3::expr &, const z3::expr &, const z3::expr &)>>>(
std::function<z3::expr(const z3::expr &, const z3::expr &,
const z3::expr &)>>>(
&wireTransform)) {
mlir::Value firstOperand = std::get<0>(transform->first);
mlir::Value secondOperand = std::get<1>(transform->first);
mlir::Value thirdOperand = std::get<2>(transform->first);
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &,
const z3::expr &)>
std::function<z3::expr(const z3::expr &, const z3::expr &,
const z3::expr &)>
transformFunc = transform->second;
z3::expr firstOperandExpr = stateTable.find(firstOperand)->second;
z3::expr secondOperandExpr = stateTable.find(secondOperand)->second;
Expand All @@ -782,12 +783,12 @@ void Solver::Circuit::applyCombUpdates() {
void Solver::Circuit::applyCombVariadicOperation(
mlir::Value result,
std::pair<mlir::OperandRange,
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>>
std::function<z3::expr(const z3::expr &, const z3::expr &)>>
operationPair) {
LLVM_DEBUG(lec::dbgs() << "comb variadic operation\n");
lec::Scope indent;
mlir::OperandRange operands = operationPair.first;
llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)> operation =
std::function<z3::expr(const z3::expr &, const z3::expr &)> operation =
operationPair.second;
// Vacuous base case.
auto it = operands.begin();
Expand Down

0 comments on commit 5e6bdce

Please sign in to comment.