Skip to content

Commit

Permalink
add more comments for function description; use llvm join for string …
Browse files Browse the repository at this point in the history
…concat
  • Loading branch information
jiahanxie353 committed Jun 12, 2024
1 parent 8157945 commit d581ee4
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,8 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
std::string funcName = "func_" + funcOp.getSymName().str();
rewriter.modifyOpInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });

/// Mark this component as the toplevel if it's the top-level function of the module.
/// Mark this component as the toplevel if it's the top-level function of
/// the module.
if (compOp.getName() == loweringState().getTopLevelFunction())
compOp->setAttr("toplevel", rewriter.getUnitAttr());

Expand All @@ -1098,7 +1099,8 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
unsigned extMemCounter = 0;
for (auto arg : enumerate(funcOp.getArguments())) {
if (isa<MemRefType>(arg.value().getType())) {
std::string memName = "arg_mem" + std::to_string(extMemCounter++);
std::string memName =
llvm::join_items("_", "arg_mem", std::to_string(extMemCounter++));

rewriter.setInsertionPointToStart(compOp.getBodyBlock());
MemRefType memtype = cast<MemRefType>(arg.value().getType());
Expand Down Expand Up @@ -1661,7 +1663,7 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {
}
}

return createOptNewTopLevelFcn(moduleOp, topLevelFunction);
return createOptNewTopLevelFn(moduleOp, topLevelFunction);
}

struct LoweringPattern {
Expand Down Expand Up @@ -1758,11 +1760,12 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {
LogicalResult partialPatternRes;
std::shared_ptr<calyx::CalyxLoweringState> loweringState = nullptr;

FuncOp createNewTopLevelFcn(ModuleOp moduleOp, std::string &baseName) {
/// Creates a new new top-level function based on `baseName`.
FuncOp createNewTopLevelFn(ModuleOp moduleOp, std::string &baseName) {
std::string newName = baseName;
unsigned counter = 0;
while (SymbolTable::lookupSymbolIn(moduleOp, newName)) {
newName = baseName + "_" + std::to_string(++counter);
newName = llvm::join_items("_", baseName, std::to_string(++counter));
}

OpBuilder builder(moduleOp.getContext());
Expand All @@ -1781,6 +1784,9 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {
return nullptr;
}

/// Insert a call from the newly created top-level function/`caller` to the
/// old top-level function/`callee`; and create `memref.alloc`s inside the new
/// top-level function for arguments with `memref` types.
void insertCallFromNewTopLevel(OpBuilder &builder, FuncOp caller,
FuncOp callee) {
if (caller.getBody().empty()) {
Expand Down Expand Up @@ -1808,8 +1814,11 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {
memRefArgs);
}

LogicalResult createOptNewTopLevelFcn(ModuleOp moduleOp,
std::string &topLevelFunction) {
/// Conditionally creates an optional new top-level function; and inserts a
/// call from the new top-level function to the old top-level function if we
/// did create one
LogicalResult createOptNewTopLevelFn(ModuleOp moduleOp,
std::string &topLevelFunction) {
auto hasMemrefArguments = [](FuncOp func) {
return std::any_of(
func.getArguments().begin(), func.getArguments().end(),
Expand All @@ -1830,7 +1839,7 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {

std::string oldName = topLevelFunction;
if (hasMemrefArgsInTopLevel) {
auto newTopLevelFunc = createNewTopLevelFcn(moduleOp, topLevelFunction);
auto newTopLevelFunc = createNewTopLevelFn(moduleOp, topLevelFunction);

OpBuilder builder(moduleOp.getContext());
Operation *oldTopLevelFuncOp =
Expand Down

0 comments on commit d581ee4

Please sign in to comment.