Skip to content

Commit

Permalink
[Calyx] bit cast floating point/signed integer types to signless inte…
Browse files Browse the repository at this point in the history
…ger as bit vector (#7977)

* buildLibraryOp cast floating point and signed integer types to signless integer types

* enable build return arguments with signed integer type by making convIndexType to a more generic normalize type function


Co-authored-by: Chris Gyurgyik <[email protected]>

---------

Co-authored-by: Chris Gyurgyik <[email protected]>
  • Loading branch information
jiahanxie353 and cgyurgyik authored Dec 13, 2024
1 parent 387291f commit 5ef9ea9
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 19 deletions.
21 changes: 18 additions & 3 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ bool noStoresToMemory(Value memoryReference);
// Get the index'th output port of compOp.
Value getComponentOutput(calyx::ComponentOp compOp, unsigned outPortIdx);

// If the provided type is an index type, converts it to i32, else, returns the
// unmodified type.
Type convIndexType(OpBuilder &builder, Type type);
// If the provided type is an index type, converts it to i32; else if the
// provided is an integer or floating point, bitcasts it to a signless integer
// type; otherwise, returns the unmodified type.
Type normalizeType(OpBuilder &builder, Type type);

// Creates a new calyx::CombGroupOp or calyx::GroupOp group within compOp.
template <typename TGroup>
Expand Down Expand Up @@ -838,6 +839,20 @@ struct PredicateInfo {

PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred);

/// Performs a bit cast from a non-signless integer type value, such as a
/// floating point value, to a signless integer type. Calyx treats everything as
/// bit vectors, and leaves their interpretation to the respective operation
/// using it. In CIRCT Calyx, we use signless `IntegerType` to represent a bit
/// vector.
template <typename T>
Type toBitVector(T type) {
if (!type.isSignlessInteger()) {
unsigned bitWidth = cast<T>(type).getIntOrFloatBitWidth();
return IntegerType::get(type.getContext(), bitWidth);
}
return type;
};

} // namespace calyx
} // namespace circt

Expand Down
8 changes: 4 additions & 4 deletions lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,8 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
IndexCastOp op) const {
Type sourceType = calyx::convIndexType(rewriter, op.getOperand().getType());
Type targetType = calyx::convIndexType(rewriter, op.getResult().getType());
Type sourceType = calyx::normalizeType(rewriter, op.getOperand().getType());
Type targetType = calyx::normalizeType(rewriter, op.getResult().getType());
unsigned targetBits = targetType.getIntOrFloatBitWidth();
unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
LogicalResult res = success();
Expand Down Expand Up @@ -793,7 +793,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
funcOpArgRewrites[arg.value()] = inPorts.size();
inPorts.push_back(calyx::PortInfo{
rewriter.getStringAttr(inName),
calyx::convIndexType(rewriter, arg.value().getType()),
calyx::normalizeType(rewriter, arg.value().getType()),
calyx::Direction::Input,
DictionaryAttr::get(rewriter.getContext(), {})});
}
Expand All @@ -802,7 +802,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
funcOpResultMapping[res.index()] = outPorts.size();
outPorts.push_back(calyx::PortInfo{
rewriter.getStringAttr("out" + std::to_string(res.index())),
calyx::convIndexType(rewriter, res.value()), calyx::Direction::Output,
calyx::normalizeType(rewriter, res.value()), calyx::Direction::Output,
DictionaryAttr::get(rewriter.getContext(), {})});
}

Expand Down
14 changes: 8 additions & 6 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
TypeRange srcTypes, TypeRange dstTypes) const {
SmallVector<Type> types;
llvm::append_range(types, srcTypes);
llvm::append_range(types, dstTypes);
for (Type srcType : srcTypes)
types.push_back(calyx::toBitVector(srcType));
for (Type dstType : dstTypes)
types.push_back(calyx::toBitVector(dstType));

auto calyxOp =
getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
Expand Down Expand Up @@ -1387,8 +1389,8 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
IndexCastOp op) const {
Type sourceType = calyx::convIndexType(rewriter, op.getOperand().getType());
Type targetType = calyx::convIndexType(rewriter, op.getResult().getType());
Type sourceType = calyx::normalizeType(rewriter, op.getOperand().getType());
Type targetType = calyx::normalizeType(rewriter, op.getResult().getType());
unsigned targetBits = targetType.getIntOrFloatBitWidth();
unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
LogicalResult res = success();
Expand Down Expand Up @@ -1577,7 +1579,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
funcOpArgRewrites[arg.value()] = inPorts.size();
inPorts.push_back(calyx::PortInfo{
rewriter.getStringAttr(inName),
calyx::convIndexType(rewriter, arg.value().getType()),
calyx::normalizeType(rewriter, arg.value().getType()),
calyx::Direction::Input,
DictionaryAttr::get(rewriter.getContext(), {})});
}
Expand All @@ -1593,7 +1595,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {

outPorts.push_back(calyx::PortInfo{
rewriter.getStringAttr(resName),
calyx::convIndexType(rewriter, res.value()), calyx::Direction::Output,
calyx::normalizeType(rewriter, res.value()), calyx::Direction::Output,
DictionaryAttr::get(rewriter.getContext(), {})});
}

Expand Down
12 changes: 6 additions & 6 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ Value getComponentOutput(calyx::ComponentOp compOp, unsigned outPortIdx) {
return compOp.getArgument(index);
}

Type convIndexType(OpBuilder &builder, Type type) {
Type normalizeType(OpBuilder &builder, Type type) {
if (type.isIndex())
return builder.getI32Type();
if (type.isIntOrFloat() && !type.isInteger())
return builder.getIntegerType(type.getIntOrFloatBitWidth());
if (type.isIntOrFloat())
return toBitVector(type);
return type;
}

Expand Down Expand Up @@ -522,7 +522,7 @@ ConvertIndexTypes::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
PatternRewriter &rewriter) const {
funcOp.walk([&](Block *block) {
for (Value arg : block->getArguments())
arg.setType(calyx::convIndexType(rewriter, arg.getType()));
arg.setType(calyx::normalizeType(rewriter, arg.getType()));
});

funcOp.walk([&](Operation *op) {
Expand All @@ -531,7 +531,7 @@ ConvertIndexTypes::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
if (!resType.isIndex())
continue;

result.setType(calyx::convIndexType(rewriter, resType));
result.setType(calyx::normalizeType(rewriter, resType));
auto constant = dyn_cast<mlir::arith::ConstantOp>(op);
if (!constant)
continue;
Expand Down Expand Up @@ -766,7 +766,7 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
PatternRewriter &rewriter) const {

for (auto argType : enumerate(funcOp.getResultTypes())) {
auto convArgType = calyx::convIndexType(rewriter, argType.value());
auto convArgType = calyx::normalizeType(rewriter, argType.value());
assert((isa<IntegerType>(convArgType) || isa<FloatType>(convArgType)) &&
"unsupported return type");
std::string name = "ret_arg" + std::to_string(argType.index());
Expand Down
92 changes: 92 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,95 @@ module {
}
}

// -----

// Test lowering SelectOp and CmpFOp with floating point operands

// CHECK: %std_mux_1.cond, %std_mux_1.tru, %std_mux_1.fal, %std_mux_1.out = calyx.std_mux @std_mux_1 : i1, i64, i64, i64
// CHECK-DAG: %unordered_port_1_reg.in, %unordered_port_1_reg.write_en, %unordered_port_1_reg.clk, %unordered_port_1_reg.reset, %unordered_port_1_reg.out, %unordered_port_1_reg.done = calyx.register @unordered_port_1_reg : i1, i1, i1, i1, i1, i1
// CHECK-DAG: %cmpf_1_reg.in, %cmpf_1_reg.write_en, %cmpf_1_reg.clk, %cmpf_1_reg.reset, %cmpf_1_reg.out, %cmpf_1_reg.done = calyx.register @cmpf_1_reg : i1, i1, i1, i1, i1, i1
// CHECK-DAG: %std_compareFN_1.clk, %std_compareFN_1.reset, %std_compareFN_1.go, %std_compareFN_1.left, %std_compareFN_1.right, %std_compareFN_1.signaling, %std_compareFN_1.lt, %std_compareFN_1.eq, %std_compareFN_1.gt, %std_compareFN_1.unordered, %std_compareFN_1.exceptionalFlags, %std_compareFN_1.done = calyx.ieee754.compare @std_compareFN_1 : i1, i1, i1, i64, i64, i1, i1, i1, i1, i1, i5, i1
// CHECK-DAG: %std_mux_0.cond, %std_mux_0.tru, %std_mux_0.fal, %std_mux_0.out = calyx.std_mux @std_mux_0 : i1, i64, i64, i64
// CHECK-DAG: %std_and_0.left, %std_and_0.right, %std_and_0.out = calyx.std_and @std_and_0 : i1, i1, i1
// CHECK-DAG: %std_or_0.left, %std_or_0.right, %std_or_0.out = calyx.std_or @std_or_0 : i1, i1, i1
// CHECK-DAG: %unordered_port_0_reg.in, %unordered_port_0_reg.write_en, %unordered_port_0_reg.clk, %unordered_port_0_reg.reset, %unordered_port_0_reg.out, %unordered_port_0_reg.done = calyx.register @unordered_port_0_reg : i1, i1, i1, i1, i1, i1
// CHECK-DAG: %compare_port_0_reg.in, %compare_port_0_reg.write_en, %compare_port_0_reg.clk, %compare_port_0_reg.reset, %compare_port_0_reg.out, %compare_port_0_reg.done = calyx.register @compare_port_0_reg : i1, i1, i1, i1, i1, i1
// CHECK-DAG: %cmpf_0_reg.in, %cmpf_0_reg.write_en, %cmpf_0_reg.clk, %cmpf_0_reg.reset, %cmpf_0_reg.out, %cmpf_0_reg.done = calyx.register @cmpf_0_reg : i1, i1, i1, i1, i1, i1
// CHECK-DAG: %std_compareFN_0.clk, %std_compareFN_0.reset, %std_compareFN_0.go, %std_compareFN_0.left, %std_compareFN_0.right, %std_compareFN_0.signaling, %std_compareFN_0.lt, %std_compareFN_0.eq, %std_compareFN_0.gt, %std_compareFN_0.unordered, %std_compareFN_0.exceptionalFlags, %std_compareFN_0.done = calyx.ieee754.compare @std_compareFN_0 : i1, i1, i1, i64, i64, i1, i1, i1, i1, i1, i5, i1
// CHECK: calyx.wires {
// CHECK: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_compareFN_0.left = %in0 : i64
// CHECK-DAG: calyx.assign %std_compareFN_0.right = %in1 : i64
// CHECK-DAG: calyx.assign %std_compareFN_0.signaling = %true : i1
// CHECK-DAG: calyx.assign %compare_port_0_reg.write_en = %std_compareFN_0.done : i1
// CHECK-DAG: calyx.assign %compare_port_0_reg.in = %std_compareFN_0.gt : i1
// CHECK-DAG: calyx.assign %unordered_port_0_reg.write_en = %std_compareFN_0.done : i1
// CHECK-DAG: calyx.assign %unordered_port_0_reg.in = %std_compareFN_0.unordered : i1
// CHECK-DAG: calyx.assign %std_or_0.left = %compare_port_0_reg.out : i1
// CHECK-DAG: calyx.assign %std_or_0.right = %unordered_port_0_reg.out : i1
// CHECK-DAG: calyx.assign %std_and_0.left = %compare_port_0_reg.done : i1
// CHECK-DAG: calyx.assign %std_and_0.right = %unordered_port_0_reg.done : i1
// CHECK-DAG: calyx.assign %cmpf_0_reg.in = %std_or_0.out : i1
// CHECK-DAG: calyx.assign %cmpf_0_reg.write_en = %std_and_0.out : i1
// CHECK-DAG: %0 = comb.xor %std_compareFN_0.done, %true : i1
// CHECK-DAG: calyx.assign %std_compareFN_0.go = %0 ? %true : i1
// CHECK-DAG: calyx.group_done %cmpf_0_reg.done : i1
// CHECK-DAG: }
// CHECK: calyx.group @bb0_2 {
// CHECK-DAG: calyx.assign %std_compareFN_1.left = %in1 : i64
// CHECK-DAG: calyx.assign %std_compareFN_1.right = %in1 : i64
// CHECK-DAG: calyx.assign %std_compareFN_1.signaling = %false : i1
// CHECK-DAG: calyx.assign %unordered_port_1_reg.write_en = %std_compareFN_1.done : i1
// CHECK-DAG: calyx.assign %unordered_port_1_reg.in = %std_compareFN_1.unordered : i1
// CHECK-DAG: calyx.assign %cmpf_1_reg.in = %unordered_port_1_reg.out : i1
// CHECK-DAG: calyx.assign %cmpf_1_reg.write_en = %unordered_port_1_reg.out : i1
// CHECK-DAG: %0 = comb.xor %std_compareFN_1.done, %true : i1
// CHECK-DAG: calyx.assign %std_compareFN_1.go = %0 ? %true : i1
// CHECK-DAG: calyx.group_done %cmpf_1_reg.done : i1
// CHECK-DAG: }
// CHECK: calyx.group @ret_assign_0 {
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %std_mux_1.out : i64
// CHECK-DAG: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %std_mux_1.cond = %cmpf_1_reg.out : i1
// CHECK-DAG: calyx.assign %std_mux_1.tru = %in1 : i64
// CHECK-DAG: calyx.assign %std_mux_1.fal = %std_mux_0.out : i64
// CHECK-DAG: calyx.assign %std_mux_0.cond = %cmpf_0_reg.out : i1
// CHECK-DAG: calyx.assign %std_mux_0.tru = %in0 : i64
// CHECK-DAG: calyx.assign %std_mux_0.fal = %in1 : i64
// CHECK-DAG: calyx.group_done %ret_arg0_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: }

module {
func.func @main(%arg0: f64, %arg1: f64) -> f64 {
%0 = arith.cmpf ugt, %arg0, %arg1 : f64
%1 = arith.select %0, %arg0, %arg1 : f64
%2 = arith.cmpf uno, %arg1, %arg1 : f64
%3 = arith.select %2, %arg1, %1 : f64
return %3 : f64
}
}

// Test SelectOp with signed integer type to signless integer type

// -----

// CHECK: %std_mux_0.cond, %std_mux_0.tru, %std_mux_0.fal, %std_mux_0.out = calyx.std_mux @std_mux_0 : i1, i32, i32, i32
// CHECK-DAG: %ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
// CHECK: calyx.wires {
// CHECK: calyx.group @ret_assign_0 {
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %std_mux_0.out : i32
// CHECK-DAG: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %std_mux_0.cond = %in2 : i1
// CHECK-DAG: calyx.assign %std_mux_0.tru = %in0 : i32
// CHECK-DAG: calyx.assign %std_mux_0.fal = %in1 : i32
// CHECK-DAG: calyx.group_done %ret_arg0_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: }

module {
func.func @main(%true : si32, %false: si32, %cond: i1) -> si32 {
%res = "arith.select" (%cond, %true, %false) : (i1, si32, si32) -> si32
return %res : si32
}
}

0 comments on commit 5ef9ea9

Please sign in to comment.