From 9af79669551d9b8bf09f258c5256df4d4a6494f9 Mon Sep 17 00:00:00 2001 From: Jiahan Xie Date: Tue, 9 Jul 2024 10:32:39 -0400 Subject: [PATCH] use ArrayAttr to store reference mapping so it's more neat; change the order of insertion point so that the number of the reference and the external memory align in the invoke --- include/circt/Dialect/Calyx/CalyxControl.td | 2 +- lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 31 +++++++++++-------- lib/Dialect/Calyx/CalyxOps.cpp | 24 ++++++++------ lib/Dialect/Calyx/Export/CalyxEmitter.cpp | 12 ++++--- .../Conversion/SCFToCalyx/convert_memory.mlir | 2 +- 5 files changed, 43 insertions(+), 28 deletions(-) diff --git a/include/circt/Dialect/Calyx/CalyxControl.td b/include/circt/Dialect/Calyx/CalyxControl.td index b63290afcd04..eb7eee3c60c9 100644 --- a/include/circt/Dialect/Calyx/CalyxControl.td +++ b/include/circt/Dialect/Calyx/CalyxControl.td @@ -450,7 +450,7 @@ def InvokeOp : CalyxOp<"invoke", [ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$ports, Variadic:$inputs, - DictionaryAttr:$refCellsMap, + ArrayAttr:$refCellsMap, ArrayAttr:$portNames, ArrayAttr:$inputNames); let results = (outs); diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index bbe92e21357e..46d05f286ac7 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -1403,7 +1403,7 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { SmallVector instancePorts; SmallVector inputPorts; - NamedAttrList refCells; + SmallVector refCells; for (auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) { auto operand = operandEnum.value(); auto index = operandEnum.index(); @@ -1414,12 +1414,16 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { auto memOpNameAttr = SymbolRefAttr::get(rewriter.getContext(), memOpName); Value argI = calleeFunc.getArgument(index); - if (isa(argI.getType())) - refCells.append(NamedAttribute( + if (isa(argI.getType())) { + NamedAttrList namedAttrList; + namedAttrList.append( rewriter.getStringAttr( instanceOpLoweringState->getMemoryInterface(argI) .memName()), - memOpNameAttr)); + memOpNameAttr); + refCells.push_back( + DictionaryAttr::get(rewriter.getContext(), namedAttrList)); + } } else { inputPorts.push_back(operand); } @@ -1427,8 +1431,8 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { llvm::copy(instanceOp.getResults().take_front(inputPorts.size()), std::back_inserter(instancePorts)); - DictionaryAttr refCellsAttr = - refCells.getDictionary(rewriter.getContext()); + ArrayAttr refCellsAttr = + ArrayAttr::get(rewriter.getContext(), refCells); rewriter.create( instanceOp.getLoc(), instanceOp.getSymName(), instancePorts, @@ -1811,7 +1815,8 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase { } FunctionType callerFnType = caller.getFunctionType(); - SmallVector updatedCallerArgTypes(callerFnType.getInputs()); + SmallVector updatedCallerArgTypes( + caller.getFunctionType().getInputs()); updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(), nonMemRefCalleeArgTypes.end()); caller.setType(FunctionType::get(caller.getContext(), updatedCallerArgTypes, @@ -1824,12 +1829,10 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase { SmallVector extraMemRefArgTypes; SmallVector extraMemRefOperands; SmallVector opsToModify; - for (auto &block : callee.getBody()) { - for (auto &op : block) { - if (isa(op) || isa(op) || - isa(op)) - opsToModify.push_back(&op); - } + for (auto &op : callee.getBody().getOps()) { + if (isa(op) || isa(op) || + isa(op)) + opsToModify.push_back(&op); } // Replace `alloc`/`getGlobal` in the original top-level with new @@ -1866,6 +1869,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase { unsigned otherArgsCount = 0; SmallVector calleeArgFnOperands; + builder.setInsertionPointToStart(callerEntryBlock); for (auto arg : callee.getArguments().take_front(originalCalleeArgNum)) { if (isa(arg.getType())) { auto memrefType = cast(arg.getType()); @@ -1885,6 +1889,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase { SymbolRefAttr::get(builder.getContext(), callee.getSymName()); auto resultTypes = callee.getResultTypes(); + builder.setInsertionPointToEnd(callerEntryBlock); builder.create(caller.getLoc(), calleeName, resultTypes, fnOperands); } diff --git a/lib/Dialect/Calyx/CalyxOps.cpp b/lib/Dialect/Calyx/CalyxOps.cpp index 047fda4563f4..1675ef677ac4 100644 --- a/lib/Dialect/Calyx/CalyxOps.cpp +++ b/lib/Dialect/Calyx/CalyxOps.cpp @@ -2606,26 +2606,28 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { FlatSymbolRefAttr callee = FlatSymbolRefAttr::get(componentName); SMLoc loc = parser.getCurrentLocation(); - SmallVector refCellSymbols; + SmallVector refCells; if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseCommaSeparatedList([&]() -> ParseResult { std::string refCellName; std::string externalMem; + NamedAttrList refCellAttr; if (parser.parseKeywordOrString(&refCellName) || parser.parseEqual() || parser.parseKeywordOrString(&externalMem)) return failure(); auto externalMemAttr = SymbolRefAttr::get(parser.getContext(), externalMem); - refCellSymbols.push_back( - NamedAttribute(StringAttr::get(parser.getContext(), refCellName), - externalMemAttr)); + refCellAttr.append(StringAttr::get(parser.getContext(), refCellName), + externalMemAttr); + refCells.push_back( + DictionaryAttr::get(parser.getContext(), refCellAttr)); return success(); }) || parser.parseRSquare()) return failure(); } result.addAttribute("refCellsMap", - DictionaryAttr::get(parser.getContext(), refCellSymbols)); + ArrayAttr::get(parser.getContext(), refCells)); result.addAttribute("callee", callee); if (parseParameterList(parser, result, ports, inputs, portNames, inputNames, @@ -2645,10 +2647,14 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { void InvokeOp::print(OpAsmPrinter &p) { p << " @" << getCallee() << "["; auto refCellNamesMap = getRefCellsMap(); - llvm::interleaveComma(refCellNamesMap, p, [&](auto arg) { - auto refCellName = arg.getName().str(); - auto externalMem = cast(arg.getValue()).getValue(); - p << refCellName << " = " << externalMem; + llvm::interleaveComma(refCellNamesMap, p, [&](Attribute attr) { + auto dictAttr = cast(attr); + llvm::interleaveComma(dictAttr, p, [&](NamedAttribute namedAttr) { + auto refCellName = namedAttr.getName().str(); + auto externalMem = + cast(namedAttr.getValue()).getValue(); + p << refCellName << " = " << externalMem; + }); }); p << "]("; diff --git a/lib/Dialect/Calyx/Export/CalyxEmitter.cpp b/lib/Dialect/Calyx/Export/CalyxEmitter.cpp index 2e3e1d9f0ca7..bb82724d882c 100644 --- a/lib/Dialect/Calyx/Export/CalyxEmitter.cpp +++ b/lib/Dialect/Calyx/Export/CalyxEmitter.cpp @@ -856,10 +856,14 @@ void Emitter::emitInvoke(InvokeOp invoke) { auto refCellsMap = invoke.getRefCellsMap(); if (!refCellsMap.empty()) { os << "["; - llvm::interleaveComma(refCellsMap, os, [&](auto refCell) { - auto refCellName = refCell.getName().str(); - auto externalMem = cast(refCell.getValue()).getValue(); - os << refCellName << " = " << externalMem; + llvm::interleaveComma(refCellsMap, os, [&](Attribute attr) { + auto dictAttr = cast(attr); + llvm::interleaveComma(dictAttr, os, [&](NamedAttribute namedAttr) { + auto refCellName = namedAttr.getName().str(); + auto externalMem = + cast(namedAttr.getValue()).getValue(); + os << refCellName << " = " << externalMem; + }); }); os << "]"; } diff --git a/test/Conversion/SCFToCalyx/convert_memory.mlir b/test/Conversion/SCFToCalyx/convert_memory.mlir index dded99ac20cc..1de93cafb342 100644 --- a/test/Conversion/SCFToCalyx/convert_memory.mlir +++ b/test/Conversion/SCFToCalyx/convert_memory.mlir @@ -833,7 +833,7 @@ module { // CHECK: calyx.control { // CHECK: calyx.seq { // CHECK: calyx.enable @init_main_instance -// CHECK: calyx.invoke @main_instance[arg_mem_0 = mem_1, arg_mem_1 = mem_0]() -> () +// CHECK: calyx.invoke @main_instance[arg_mem_0 = mem_0, arg_mem_1 = mem_1]() -> () // CHECK: } // CHECK: } // CHECK: } {toplevel}