Skip to content

Commit

Permalink
use ArrayAttr to store reference mapping so it's more neat; change th…
Browse files Browse the repository at this point in the history
…e order of insertion point so that the number of the reference and the external memory align in the invoke
  • Loading branch information
jiahanxie353 committed Jul 9, 2024
1 parent 134da0f commit 9af7966
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 28 deletions.
2 changes: 1 addition & 1 deletion include/circt/Dialect/Calyx/CalyxControl.td
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def InvokeOp : CalyxOp<"invoke", [
let arguments = (ins FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$ports,
Variadic<AnyType>:$inputs,
DictionaryAttr:$refCellsMap,
ArrayAttr:$refCellsMap,
ArrayAttr:$portNames,
ArrayAttr:$inputNames);
let results = (outs);
Expand Down
31 changes: 18 additions & 13 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {

SmallVector<Value, 4> instancePorts;
SmallVector<Value, 4> inputPorts;
NamedAttrList refCells;
SmallVector<Attribute, 4> refCells;
for (auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) {
auto operand = operandEnum.value();
auto index = operandEnum.index();
Expand All @@ -1414,21 +1414,25 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
auto memOpNameAttr =
SymbolRefAttr::get(rewriter.getContext(), memOpName);
Value argI = calleeFunc.getArgument(index);
if (isa<MemRefType>(argI.getType()))
refCells.append(NamedAttribute(
if (isa<MemRefType>(argI.getType())) {
NamedAttrList namedAttrList;
namedAttrList.append(
rewriter.getStringAttr(
instanceOpLoweringState->getMemoryInterface(argI)
.memName()),
memOpNameAttr));
memOpNameAttr);
refCells.push_back(
DictionaryAttr::get(rewriter.getContext(), namedAttrList));
}
} else {
inputPorts.push_back(operand);
}
}
llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
std::back_inserter(instancePorts));

DictionaryAttr refCellsAttr =
refCells.getDictionary(rewriter.getContext());
ArrayAttr refCellsAttr =
ArrayAttr::get(rewriter.getContext(), refCells);

rewriter.create<calyx::InvokeOp>(
instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
Expand Down Expand Up @@ -1811,7 +1815,8 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
}

FunctionType callerFnType = caller.getFunctionType();
SmallVector<Type, 4> updatedCallerArgTypes(callerFnType.getInputs());
SmallVector<Type, 4> updatedCallerArgTypes(
caller.getFunctionType().getInputs());
updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(),
nonMemRefCalleeArgTypes.end());
caller.setType(FunctionType::get(caller.getContext(), updatedCallerArgTypes,
Expand All @@ -1824,12 +1829,10 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
SmallVector<Type, 4> extraMemRefArgTypes;
SmallVector<Value, 4> extraMemRefOperands;
SmallVector<Operation *, 4> opsToModify;
for (auto &block : callee.getBody()) {
for (auto &op : block) {
if (isa<memref::AllocaOp>(op) || isa<memref::AllocOp>(op) ||
isa<memref::GetGlobalOp>(op))
opsToModify.push_back(&op);
}
for (auto &op : callee.getBody().getOps()) {
if (isa<memref::AllocaOp>(op) || isa<memref::AllocOp>(op) ||
isa<memref::GetGlobalOp>(op))
opsToModify.push_back(&op);
}

// Replace `alloc`/`getGlobal` in the original top-level with new
Expand Down Expand Up @@ -1866,6 +1869,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {

unsigned otherArgsCount = 0;
SmallVector<Value, 4> calleeArgFnOperands;
builder.setInsertionPointToStart(callerEntryBlock);
for (auto arg : callee.getArguments().take_front(originalCalleeArgNum)) {
if (isa<MemRefType>(arg.getType())) {
auto memrefType = cast<MemRefType>(arg.getType());
Expand All @@ -1885,6 +1889,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
SymbolRefAttr::get(builder.getContext(), callee.getSymName());
auto resultTypes = callee.getResultTypes();

builder.setInsertionPointToEnd(callerEntryBlock);
builder.create<CallOp>(caller.getLoc(), calleeName, resultTypes,
fnOperands);
}
Expand Down
24 changes: 15 additions & 9 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2606,26 +2606,28 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
FlatSymbolRefAttr callee = FlatSymbolRefAttr::get(componentName);
SMLoc loc = parser.getCurrentLocation();

SmallVector<NamedAttribute, 4> refCellSymbols;
SmallVector<Attribute, 4> refCells;
if (succeeded(parser.parseOptionalLSquare())) {
if (parser.parseCommaSeparatedList([&]() -> ParseResult {
std::string refCellName;
std::string externalMem;
NamedAttrList refCellAttr;
if (parser.parseKeywordOrString(&refCellName) ||
parser.parseEqual() || parser.parseKeywordOrString(&externalMem))
return failure();
auto externalMemAttr =
SymbolRefAttr::get(parser.getContext(), externalMem);
refCellSymbols.push_back(
NamedAttribute(StringAttr::get(parser.getContext(), refCellName),
externalMemAttr));
refCellAttr.append(StringAttr::get(parser.getContext(), refCellName),
externalMemAttr);
refCells.push_back(
DictionaryAttr::get(parser.getContext(), refCellAttr));
return success();
}) ||
parser.parseRSquare())
return failure();
}
result.addAttribute("refCellsMap",
DictionaryAttr::get(parser.getContext(), refCellSymbols));
ArrayAttr::get(parser.getContext(), refCells));

result.addAttribute("callee", callee);
if (parseParameterList(parser, result, ports, inputs, portNames, inputNames,
Expand All @@ -2645,10 +2647,14 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
void InvokeOp::print(OpAsmPrinter &p) {
p << " @" << getCallee() << "[";
auto refCellNamesMap = getRefCellsMap();
llvm::interleaveComma(refCellNamesMap, p, [&](auto arg) {
auto refCellName = arg.getName().str();
auto externalMem = cast<FlatSymbolRefAttr>(arg.getValue()).getValue();
p << refCellName << " = " << externalMem;
llvm::interleaveComma(refCellNamesMap, p, [&](Attribute attr) {
auto dictAttr = cast<DictionaryAttr>(attr);
llvm::interleaveComma(dictAttr, p, [&](NamedAttribute namedAttr) {
auto refCellName = namedAttr.getName().str();
auto externalMem =
cast<FlatSymbolRefAttr>(namedAttr.getValue()).getValue();
p << refCellName << " = " << externalMem;
});
});
p << "](";

Expand Down
12 changes: 8 additions & 4 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,10 +856,14 @@ void Emitter::emitInvoke(InvokeOp invoke) {
auto refCellsMap = invoke.getRefCellsMap();
if (!refCellsMap.empty()) {
os << "[";
llvm::interleaveComma(refCellsMap, os, [&](auto refCell) {
auto refCellName = refCell.getName().str();
auto externalMem = cast<FlatSymbolRefAttr>(refCell.getValue()).getValue();
os << refCellName << " = " << externalMem;
llvm::interleaveComma(refCellsMap, os, [&](Attribute attr) {
auto dictAttr = cast<DictionaryAttr>(attr);
llvm::interleaveComma(dictAttr, os, [&](NamedAttribute namedAttr) {
auto refCellName = namedAttr.getName().str();
auto externalMem =
cast<FlatSymbolRefAttr>(namedAttr.getValue()).getValue();
os << refCellName << " = " << externalMem;
});
});
os << "]";
}
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/SCFToCalyx/convert_memory.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ module {
// CHECK: calyx.control {
// CHECK: calyx.seq {
// CHECK: calyx.enable @init_main_instance
// CHECK: calyx.invoke @main_instance[arg_mem_0 = mem_1, arg_mem_1 = mem_0]() -> ()
// CHECK: calyx.invoke @main_instance[arg_mem_0 = mem_0, arg_mem_1 = mem_1]() -> ()
// CHECK: }
// CHECK: }
// CHECK: } {toplevel}
Expand Down

0 comments on commit 9af7966

Please sign in to comment.