Skip to content

Commit

Permalink
[scf-to-calyx] Support for function call (#5965)
Browse files Browse the repository at this point in the history
  • Loading branch information
linuxlonelyeagle authored Sep 19, 2023
1 parent 05d84c1 commit 89f8246
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 5 deletions.
11 changes: 11 additions & 0 deletions include/circt/Dialect/Calyx/CalyxHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Support/LLVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

#include <memory>

Expand All @@ -34,6 +35,16 @@ hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
size_t value);

/// A helper function to create calyx.instance operation.
calyx::InstanceOp createInstance(Location loc, OpBuilder &builder,
ComponentOp component,
SmallVectorImpl<Type> &resultTypes,
StringRef instanceName,
StringRef componentName);

/// A helper function to get the instance name.
std::string getInstanceName(mlir::func::CallOp callOp);

// Returns whether this operation is a leaf node in the Calyx control.
// TODO(github.com/llvm/circt/issues/1679): Add Invoke.
bool isControlLeafNode(Operation *op);
Expand Down
20 changes: 20 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ class ComponentLoweringStateInterface {
/// the original function maps to.
unsigned getFuncOpResultMapping(unsigned funcReturnIdx);

/// The instance is obtained from the name of the callee.
InstanceOp getInstance(StringRef calleeName);

/// Put the name of the callee and the instance of the call into map.
void addInstance(StringRef calleeName, InstanceOp instanceOp);

/// Return the group which evaluates the value v. Optionally, caller may
/// specify the expected type of the group.
template <typename TGroupOp = calyx::GroupInterface>
Expand Down Expand Up @@ -452,6 +458,9 @@ class ComponentLoweringStateInterface {
/// A mapping between the source funcOp result indices and the corresponding
/// output port indices of this componentOp.
DenseMap<unsigned, unsigned> funcOpResultMapping;

/// A mapping between the callee and the instance.
llvm::StringMap<calyx::InstanceOp> instanceMap;
};

/// An interface for conversion passes that lower Calyx programs. This handles
Expand Down Expand Up @@ -734,6 +743,17 @@ class BuildReturnRegs : public calyx::FuncOpPartialLoweringPattern {
PatternRewriter &rewriter) const override;
};

/// Builds instance for the calyx.invoke and calyx.group in order to initialize
/// the instance.
class BuildCallInstance : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
PatternRewriter &rewriter) const override;
ComponentOp getCallComponent(mlir::func::CallOp callOp) const;
};

} // namespace calyx
} // namespace circt

Expand Down
60 changes: 56 additions & 4 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,16 @@ struct ForScheduleable {
uint64_t bound;
};

struct CallScheduleable {
/// Instance for invoking.
calyx::InstanceOp instanceOp;
// CallOp for getting the arguments.
func::CallOp callOp;
};

/// A variant of types representing scheduleable operations.
using Scheduleable =
std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable>;
using Scheduleable = std::variant<calyx::GroupOp, WhileScheduleable,
ForScheduleable, CallScheduleable>;

class WhileLoopLoweringStateInterface
: calyx::LoopLoweringStateInterface<ScfWhileOp> {
Expand Down Expand Up @@ -210,7 +217,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
SelectOp, IndexCastOp>(
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
.template Case<FuncOp, scf::ConditionOp>([&](auto) {
/// Skip: these special cases will be handled separately.
Expand Down Expand Up @@ -261,6 +268,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
/// source operation TSrcOp.
Expand Down Expand Up @@ -899,6 +907,29 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
calyx::InstanceOp instanceOp =
getState<ComponentLoweringState>().getInstance(instanceName);
SmallVector<Value, 4> outputPorts;
auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
for (auto [idx, portInfo] : enumerate(portInfos)) {
if (portInfo.direction == calyx::Direction::Output)
outputPorts.push_back(instanceOp.getResult(idx));
}

// Replacing a CallOp results in the out port of the instance.
for (auto [idx, result] : llvm::enumerate(callOp.getResults()))
rewriter.replaceAllUsesWith(result, outputPorts[idx]);

// CallScheduleanle requires an instance, while CallOp can be used to get the
// input ports.
getState<ComponentLoweringState>().addBlockScheduleable(
callOp.getOperation()->getBlock(), CallScheduleable{instanceOp, callOp});
return success();
}

/// Inlines Calyx ExecuteRegionOp operations within their parent blocks.
/// An execution region op (ERO) is inlined by:
/// i : add a sink basic block for all yield operations inside the
Expand Down Expand Up @@ -1027,6 +1058,9 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
auto compOp = rewriter.create<calyx::ComponentOp>(
funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);

std::string funcName = "func_" + funcOp.getSymName().str();
rewriter.updateRootInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });

/// Mark this component as the toplevel.
compOp->setAttr("toplevel", rewriter.getUnitAttr());

Expand Down Expand Up @@ -1313,6 +1347,21 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
forLatchGroup.getName());
if (res.failed())
return res;
} else if (auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
auto instanceOp = callSchedPtr->instanceOp;
OpBuilder::InsertionGuard g(rewriter);
auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
rewriter.setInsertionPointToStart(callBody.getBodyBlock());
std::string initGroupName = "init_" + instanceOp.getSymName().str();
rewriter.create<calyx::EnableOp>(instanceOp.getLoc(), initGroupName);
SmallVector<Value, 4> instancePorts;
auto inputPorts = callSchedPtr->callOp.getOperands();
llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
std::back_inserter(instancePorts));
rewriter.create<calyx::InvokeOp>(
instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
inputPorts, ArrayAttr::get(rewriter.getContext(), {}),
ArrayAttr::get(rewriter.getContext(), {}));
} else
llvm_unreachable("Unknown scheduleable");
}
Expand Down Expand Up @@ -1571,7 +1620,7 @@ class SCFToCalyxPass : public SCFToCalyxBase<SCFToCalyxPass> {
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp>();
ExtSIOp, CallOp>();

RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());
Expand Down Expand Up @@ -1678,6 +1727,9 @@ void SCFToCalyxPass::runOnOperation() {
addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
funcMap, *loweringState);

addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
funcMap, *loweringState);

/// This pattern creates registers for the function return values.
addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
funcMap, *loweringState);
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/Calyx/Transforms/CalyxHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
APInt(width, value, /*unsigned=*/true));
}

calyx::InstanceOp createInstance(Location loc, OpBuilder &builder,
ComponentOp component,
SmallVectorImpl<Type> &resultTypes,
StringRef instanceName,
StringRef componentName) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(component.getBodyBlock());
return builder.create<InstanceOp>(loc, resultTypes, instanceName,
componentName);
}

std::string getInstanceName(mlir::func::CallOp callOp) {
SmallVector<StringRef, 2> strVet = {callOp.getCallee(), "instance"};
return llvm::join(strVet, /*separator=*/"_");
}

bool isControlLeafNode(Operation *op) { return isa<calyx::EnableOp>(op); }

DictionaryAttr getMandatoryPortAttr(MLIRContext *ctx, StringRef name) {
Expand Down
70 changes: 69 additions & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,15 @@ unsigned ComponentLoweringStateInterface::getFuncOpResultMapping(
return it->second;
}

InstanceOp ComponentLoweringStateInterface::getInstance(StringRef calleeName) {
return instanceMap[calleeName];
}

void ComponentLoweringStateInterface::addInstance(StringRef calleeName,
InstanceOp instanceOp) {
instanceMap[calleeName] = instanceOp;
}

//===----------------------------------------------------------------------===//
// CalyxLoweringState
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -633,7 +642,8 @@ void InlineCombGroups::recurseInlineCombGroups(
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp>(src.getDefiningOp()))
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp>(
src.getDefiningOp()))
continue;

auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
Expand Down Expand Up @@ -744,5 +754,63 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
return success();
}

//===----------------------------------------------------------------------===//
// BuildCallInstance
//===----------------------------------------------------------------------===//

LogicalResult
BuildCallInstance::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
PatternRewriter &rewriter) const {
funcOp.walk([&](mlir::func::CallOp callOp) {
ComponentOp componentOp = getCallComponent(callOp);
SmallVector<Type, 8> resultTypes;
for (auto type : componentOp.getArgumentTypes())
resultTypes.push_back(type);
for (auto type : componentOp.getResultTypes())
resultTypes.push_back(type);
std::string instanceName = getInstanceName(callOp);

// Determines if an instance needs to be created. If the same function was
// called by CallOp before, it doesn't need to be created, if not, the
// instance is created.
if (!getState().getInstance(instanceName)) {
InstanceOp instanceOp =
createInstance(callOp.getLoc(), rewriter, getComponent(), resultTypes,
instanceName, componentOp.getName());
getState().addInstance(instanceName, instanceOp);
hw::ConstantOp constantOp =
createConstant(callOp.getLoc(), rewriter, getComponent(), 1, 1);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(
getComponent().getWiresOp().getBodyBlock());

// Creates the group that initializes the instance.
calyx::GroupOp groupOp = rewriter.create<calyx::GroupOp>(
callOp.getLoc(), "init_" + instanceName);
rewriter.setInsertionPointToStart(groupOp.getBodyBlock());
auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
auto results = instanceOp.getResults();
for (const auto &[portInfo, result] : llvm::zip(portInfos, results)) {
if (portInfo.hasAttribute("go") || portInfo.hasAttribute("reset"))
rewriter.create<calyx::AssignOp>(callOp.getLoc(), result, constantOp);
else if (portInfo.hasAttribute("done"))
rewriter.create<calyx::GroupDoneOp>(callOp.getLoc(), result);
}
}
WalkResult::advance();
});
return success();
}

ComponentOp
BuildCallInstance::getCallComponent(mlir::func::CallOp callOp) const {
std::string callee = "func_" + callOp.getCallee().str();
for (auto [funcOp, componentOp] : functionMapping) {
if (funcOp.getSymName() == callee)
return componentOp;
}
return nullptr;
}

} // namespace calyx
} // namespace circt
97 changes: 97 additions & 0 deletions test/Conversion/SCFToCalyx/convert_func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// RUN: circt-opt %s --lower-scf-to-calyx="top-level-function=main" -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: calyx.enable @init_func_instance
// CHECK: calyx.invoke @func_instance(%[[VAL_0:.*]] = %[[VAL_1:.*]]) -> (i32)

module {
func.func @func(%0 : i32) -> i32 {
return %0 : i32
}

func.func @main() -> i32 {
%0 = arith.constant 0 : i32
%1 = func.call @func(%0) : (i32) -> i32
func.return %1 : i32
}
}

// -----

// CHECK-LABEL: calyx.enable @init_func_instance
// CHECK: calyx.invoke @func_instance(%[[VAL_0:.*]] = %[[VAL_1:.*]]) -> (i32)

module {
func.func @func(%0 : i32) -> i32 {
return %0 : i32
}

func.func @main() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%0 = memref.alloc() : memref<64xi32>
%1 = memref.alloc() : memref<64xi32>
scf.while(%arg0 = %c0) : (index) -> (index) {
%cond = arith.cmpi ult, %arg0, %c64 : index
scf.condition(%cond) %arg0 : index
} do {
^bb0(%arg1: index):
%v = memref.load %0[%arg1] : memref<64xi32>
%c = func.call @func(%v) : (i32) -> i32
memref.store %c, %1[%arg1] : memref<64xi32>
%inc = arith.addi %arg1, %c1 : index
scf.yield %inc : index
}
return
}
}

// -----

// CHECK-LABEL: calyx.enable @init_fun_instance
// CHECK: calyx.invoke @fun_instance(%[[VAL_0:.*]] = %[[VAL_1:.*]]) -> (i32)

module {
func.func @fun(%0 : i32) -> i32 {
return %0 : i32
}

func.func @main() {
%alloca = memref.alloca() : memref<40xi32>
%c0 = arith.constant 0 : index
%c40 = arith.constant 40 : index
%c1 = arith.constant 1 : index
scf.for %arg0 = %c0 to %c40 step %c1 {
%0 = memref.load %alloca[%arg0] : memref<40xi32>
%1 = func.call @fun(%0) : (i32) -> i32
memref.store %1, %alloca[%arg0] : memref<40xi32>
}
return
}
}

// -----

// CHECK-LABEL: calyx.enable @init_func_instance
// CHECK: calyx.invoke @func_instance(%[[VAL_0:.*]] = %[[VAL_1:.*]]) -> (i32)
// CHECK: calyx.enable @init_func_instance
// CHECK: calyx.invoke @func_instance(%[[VAL_2:.*]] = %[[VAL_3:.*]]) -> (i32)

module {
func.func @func(%0 : i32) -> i32 {
return %0 : i32
}

func.func @main(%a0 : i32, %a1 : i32, %a2 : i32) -> i32 {
%0 = arith.addi %a0, %a1 : i32
%1 = arith.addi %0, %a1 : i32
%b = arith.cmpi uge, %1, %a2 : i32
cf.cond_br %b, ^bb1, ^bb2
^bb1:
%ret0 = func.call @func(%0) : (i32) -> i32
return %ret0 : i32
^bb2:
%ret1 = func.call @func(%1) : (i32) -> i32
return %ret1 : i32
}
}

0 comments on commit 89f8246

Please sign in to comment.