Skip to content

Commit

Permalink
add new conditional overwrite template
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Oct 26, 2023
1 parent 868f7e9 commit 7bd5c72
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 2 deletions.
3 changes: 2 additions & 1 deletion rtlrepair.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
'replace_literals': replace_literals,
'assign_const': assign_const,
'add_inversions': add_inversions,
'replace_variables': replace_variables
'replace_variables': replace_variables,
'conditional_overwrite': conditional_overwrite,
}
_default_templates = ['replace_literals', 'assign_const', 'add_inversions', 'replace_variables']

Expand Down
3 changes: 2 additions & 1 deletion rtlrepair/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from rtlrepair.templates.add_inversions import add_inversions
from rtlrepair.templates.replace_literals import replace_literals
from rtlrepair.templates.replace_variables import replace_variables
from rtlrepair.templates.assign_const import assign_const
from rtlrepair.templates.assign_const import assign_const
from rtlrepair.templates.conditional_overwrite import conditional_overwrite
12 changes: 12 additions & 0 deletions rtlrepair/templates/assign_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self):
self.assigned_vars = set()
self.blocking_count = 0
self.non_blocking_count = 0
self.conditions = []
self.case_inputs = []

def run(self, proc: vast.Always):
self.visit(proc)
Expand All @@ -121,3 +123,13 @@ def visit_NonblockingSubstitution(self, node: vast.NonblockingSubstitution):
def visit_ForStatement(self, node: vast.ForStatement):
# ignore the condition, pre and post of the for statement
self.visit(node.statement)

def visit_IfStatement(self, node: vast.IfStatement):
self.conditions.append(node.cond)
self.visit(node.true_statement)
self.visit(node.false_statement)

def visit_CaseStatement(self, node: vast.CaseStatement):
self.case_inputs.append(node.comp)
for cc in node.caselist:
self.visit(cc)
76 changes: 76 additions & 0 deletions rtlrepair/templates/conditional_overwrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2023 The Regents of the University of California
# released under BSD 3-Clause License
# author: Kevin Laeufer <[email protected]>


from rtlrepair.repair import RepairTemplate
from rtlrepair.templates.assign_const import ProcessAnalyzer
from rtlrepair.types import InferWidths
from rtlrepair.utils import Namespace, ensure_block
import pyverilog.vparser.ast as vast

def conditional_overwrite(ast: vast.Source):
namespace = Namespace(ast)
infer = InferWidths()
infer.run(ast)
repl = ConditionalOverwrite(infer.widths)
repl.apply(namespace, ast)
return repl.blockified


class ConditionalOverwrite(RepairTemplate):
def __init__(self, widths):
super().__init__(name="conditional_overwrite")
self.widths = widths
self.use_blocking = False
self.assigned_vars = []
# we use this list to track which new blocks we introduced in order to minimize the diff between
# buggy and repaired version
self.blockified = []

def visit_Always(self, node: vast.Always):
analysis = ProcessAnalyzer()
analysis.run(node)
if analysis.non_blocking_count > 0 and analysis.blocking_count > 0:
print("WARN: single always process seems to mix blocking and non-blocking assignment. Skipping.")
return node
# note: we are ignoring pointer for now since these might contain loop vars that may not always be in scope..
assigned_vars = [var for var in analysis.assigned_vars if isinstance(var, vast.Identifier)]
self.use_blocking = analysis.blocking_count > 0
# add conditional overwrites to the end of the process
stmts = []
for var in assigned_vars:
cond = self.gen_condition(analysis.conditions, analysis.case_inputs)
assignment = self.make_assignment(var)
inner = vast.IfStatement(cond, assignment, None)
stmts.append(self.make_change_stmt(inner, 0))
# append statements
node.statement = ensure_block(node.statement, self.blockified)
node.statement.statements = tuple(list(node.statement.statements) + stmts)
return node

def gen_condition(self, conditions: list, case_inputs: list) -> vast.Node:
atoms = conditions + [vast.Eq(ci, vast.Identifier(self.make_synth_var(self.widths[ci]))) for ci in case_inputs]
# atoms can be inverted
atoms_or_inv = [self.make_synth_choice(aa, vast.Ulnot(aa)) for aa in atoms]
# atoms do not need to be used
tru = vast.IntConst("1'b1")
atoms_optional = [self.make_change(aa, tru) for aa in atoms_or_inv]
# combine all atoms together
node = atoms_optional[0]
for aa in atoms_optional[1:]:
node = vast.And(node, aa)
return node

def make_synth_choice(self, a, b):
name = self.make_synth_var(1)
return vast.Cond(vast.Identifier(name), a, b)

def make_assignment(self, var):
width = self.widths[var]
const = vast.Identifier(self.make_synth_var(width))
if self.use_blocking:
assign = vast.BlockingSubstitution(vast.Lvalue(var), vast.Rvalue(const))
else:
assign = vast.NonblockingSubstitution(vast.Lvalue(var), vast.Rvalue(const))
return assign

0 comments on commit 7bd5c72

Please sign in to comment.