Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sim] Flatten format string concatenations in canonicalizer #7316

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion include/circt/Dialect/Sim/SimOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,34 @@ def FormatStringConcatOp : SimOp<"fmt.concat", [Pure]> {
let hasVerifier = true;

let assemblyFormat = "` ` `(` $inputs `)` attr-dict";
}

let extraClassDeclaration = [{
  /// Returns true iff none of the inputs is itself the result of a format
  /// string concatenation, i.e. every input is a primitive
  /// (non-concatenated) token or a block argument.
  bool isFlat() {
    return llvm::none_of(getInputs(), [](Value operand) {
      return !!operand.getDefiningOp<circt::sim::FormatStringConcatOp>();
    });
  }

  /// Attempts to flatten this operation's input strings as much as possible.
  ///
  /// The flattened values are appended to `flatOperands`.
  /// If the concatenation is acyclic, the function returns 'success' and
  /// all the flattened values are guaranteed to _not_ be the result of a
  /// format string concatenation.
  /// If a cycle is encountered, the function returns 'failure'. In that
  /// case the result of the concat operation forming the cycle is appended
  /// to the list of flattened values and flattening continues without
  /// recursing into the cycle.
  LogicalResult getFlattenedInputs(llvm::SmallVectorImpl<Value> &flatOperands);
}];

let builders = [
  // Convenience builder: the result type is always the format string type,
  // so callers only need to supply the inputs.
  OpBuilder<(ins "mlir::ValueRange":$inputs), [{
    auto resultType =
        circt::sim::FormatStringType::get($_builder.getContext());
    build($_builder, $_state, resultType, inputs);
  }]>
];
}

#endif // CIRCT_DIALECT_SIM_SIMOPS_TD
79 changes: 74 additions & 5 deletions lib/Dialect/Sim/SimOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h"
#include "circt/Dialect/SV/SVOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionImplementation.h"

Expand Down Expand Up @@ -190,8 +192,12 @@ static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef<StringRef> lits) {
OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
if (getNumOperands() == 0)
return StringAttr::get(getContext(), "");
if (getNumOperands() == 1)
if (getNumOperands() == 1) {
// Don't fold to our own result to avoid an infinite loop.
if (getResult() == getOperand(0))
return {};
return getOperand(0);
}

// Fold if all operands are literals.
SmallVector<StringRef> lits;
Expand All @@ -204,6 +210,49 @@ OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
return concatLiterals(getContext(), lits);
}

LogicalResult FormatStringConcatOp::getFlattenedInputs(
    llvm::SmallVectorImpl<Value> &flatOperands) {
  // Explicit DFS stack. Each entry pairs a concat that is currently being
  // expanded with the index of the next operand to visit once we return to
  // it. Because it is a map, it also answers "is this concat already on the
  // current path?", which is how cycles are detected.
  llvm::SmallMapVector<FormatStringConcatOp, unsigned, 4> pending;
  bool sawCycle = false;

  pending.insert({*this, 0});
  while (!pending.empty()) {
    auto &frame = pending.back();
    FormatStringConcatOp concat = frame.first;
    const unsigned numOperands = concat.getNumOperands();
    bool descended = false;

    // Resume walking this concat's operands where we previously left off.
    for (unsigned idx = frame.second; idx < numOperands; ++idx) {
      Value operand = concat.getOperand(idx);
      auto nested = operand.getDefiningOp<FormatStringConcatOp>();

      if (nested) {
        if (!pending.contains(nested)) {
          // Nested concat that is not on the current path: remember where
          // to continue in the current concat, then descend. The index must
          // be stored before the insertion, since inserting may invalidate
          // the `frame` reference.
          frame.second = idx + 1;
          pending.insert({nested, 0});
          descended = true;
          break;
        }
        // The nested concat is already being expanded further down the
        // stack: this is a cycle. Record it and treat the operand as a leaf
        // instead of recursing.
        sawCycle = true;
      }

      flatOperands.push_back(operand);
    }

    // All operands of this concat have been handled; drop it from the stack.
    if (!descended)
      pending.pop_back();
  }

  return success(!sawCycle);
}

LogicalResult FormatStringConcatOp::verify() {
if (llvm::any_of(getOperands(),
[&](Value operand) { return operand == getResult(); }))
Expand All @@ -213,19 +262,39 @@ LogicalResult FormatStringConcatOp::verify() {

LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
PatternRewriter &rewriter) {
if (op.getNumOperands() < 2)
return failure(); // Should be handled by the folder

auto fmtStrType = FormatStringType::get(op.getContext());

// Check if we can flatten concats of concats
bool hasBeenFlattened = false;
SmallVector<Value, 0> flatOperands;
if (!op.isFlat()) {
// Get a new, flattened list of operands
flatOperands.reserve(op.getNumOperands() + 4);
auto isAcyclic = op.getFlattenedInputs(flatOperands);

if (failed(isAcyclic)) {
// Infinite recursion, but we cannot fail compilation right here (can we?)
// so just emit a warning and bail out.
op.emitWarning("Cyclic concatenation detected.");
return failure();
}

hasBeenFlattened = true;
}
Comment on lines +271 to +284
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we even have to emit diagnostics and handle failure cases here. I think the canonicalization is intended to be purely opportunistic, so if we encounter a cycle we should be able to just silently not canonicalize and return failure. Or in case getFlattenedInputs creates a valid list of operands even if a cycle is present, we could just work with that and ignore the cycle altogether. Especially since the verifier catches this already 😃

Copy link
Contributor Author

@fzi-hielscher fzi-hielscher Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is actually how I had done this originally. I then threw it out for a potentially silly reason: If we flatten despite the presence of cycles, it becomes difficult to test as it is hard to predict (and potentially unstable for future LLVM bumps?) which op is going to fail the verification first. And annoyingly I could not find any way to tell the -verify-diagnostics framework that the error may be produced at different locations. Quite possibly I'm overthinking this (I usually do). But in the end I decided to just keep it as simple as possible, given that:

  • A cyclic concatenation would almost certainly be indicative of a bug in the frontend.
  • It is pretty much guaranteed to cause a failure in the subsequent lowering pass.

I kept the warning in case something goes wrong and the cycle causes an infinite loop, so it doesn't just freeze quietly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense! Very cool stuff in any case 😎 🥳!


if (!hasBeenFlattened && op.getNumOperands() < 2)
return failure(); // Should be handled by the folder

// Check if there are adjacent literals we can merge or empty literals to
// remove
SmallVector<StringRef> litSequence;
SmallVector<Value> newOperands;
newOperands.reserve(op.getNumOperands());
FormatLitOp prevLitOp;

for (auto operand : op.getOperands()) {
auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
for (auto operand : oldOperands) {
if (auto litOp = operand.getDefiningOp<FormatLitOp>()) {
if (!litOp.getLiteral().empty()) {
prevLitOp = litOp;
Expand Down Expand Up @@ -263,7 +332,7 @@ LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
}
}

if (newOperands.size() == op.getNumOperands())
if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
return failure(); // Nothing changed

if (newOperands.empty())
Expand Down
48 changes: 48 additions & 0 deletions test/Dialect/Sim/format-strings.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,51 @@ hw.module @constant_fold3(in %zeroWitdh: i0, out res: !sim.fstring) {
%cat = sim.fmt.concat (%foo, %clf, %ccr, %null, %foo, %null, %cext)
hw.output %cat : !sim.fstring
}


// Flattening test: %catcat concatenates three two-element concats
// (%binCat, %decCat, %hexCat) interleaved with the empty concat %nocat.
// The FileCheck directives below match a single concat of the six leaf
// values in the order Bin, Dec, Hex, with no operand left for %nocat.
// CHECK-LABEL: hw.module @flatten_concat1
// CHECK-DAG: %[[LH:.+]] = sim.fmt.lit "Hex: "
// CHECK-DAG: %[[LD:.+]] = sim.fmt.lit "Dec: "
// CHECK-DAG: %[[LB:.+]] = sim.fmt.lit "Bin: "
// CHECK-DAG: %[[FH:.+]] = sim.fmt.hex %val : i8
// CHECK-DAG: %[[FD:.+]] = sim.fmt.dec %val : i8
// CHECK-DAG: %[[FB:.+]] = sim.fmt.bin %val : i8
// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[LB]], %[[FB]], %[[LD]], %[[FD]], %[[LH]], %[[FH]])
// CHECK: hw.output %[[CAT]] : !sim.fstring

hw.module @flatten_concat1(in %val : i8, out res: !sim.fstring) {
%binLit = sim.fmt.lit "Bin: "
%binVal = sim.fmt.bin %val : i8
%binCat = sim.fmt.concat (%binLit, %binVal)

%decLit = sim.fmt.lit "Dec: "
%decVal = sim.fmt.dec %val : i8
// Note: %nocat is referenced here ahead of its definition further down.
%decCat = sim.fmt.concat (%decLit, %decVal, %nocat)

// Empty concatenation, used as an operand both above and below.
%nocat = sim.fmt.concat ()

%hexLit = sim.fmt.lit "Hex: "
%hexVal = sim.fmt.hex %val : i8
%hexCat = sim.fmt.concat (%hexLit, %hexVal)

// Top-level concat of concats that the test expects to be flattened.
%catcat = sim.fmt.concat (%binCat, %nocat, %decCat, %nocat, %hexCat)
hw.output %catcat : !sim.fstring
}

// Flattening test with shared sub-concats: %c is reused by %cc and %ccccc,
// so the fully expanded operand list is five copies of (Foo, char, Foo).
// The FileCheck directives below match a single concat in which each pair
// of adjacent "Foo" literals appears as one "FooFoo" literal.
// CHECK-LABEL: hw.module @flatten_concat2
// CHECK-DAG: %[[F:.+]] = sim.fmt.lit "Foo"
// CHECK-DAG: %[[FF:.+]] = sim.fmt.lit "FooFoo"
// CHECK-DAG: %[[CHR:.+]] = sim.fmt.char %val : i8
// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[F]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[F]])
// CHECK: hw.output %[[CAT]] : !sim.fstring

hw.module @flatten_concat2(in %val : i8, out res: !sim.fstring) {
%foo = sim.fmt.lit "Foo"
%char = sim.fmt.char %val : i8

// One unit: (Foo, char, Foo).
%c = sim.fmt.concat (%foo, %char, %foo)
// Two units.
%cc = sim.fmt.concat (%c, %c)
// Five units: two + one + two.
%ccccc = sim.fmt.concat (%cc, %c, %cc)

hw.output %ccccc : !sim.fstring
}
21 changes: 19 additions & 2 deletions test/Dialect/Sim/sim-errors.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
// RUN: circt-opt %s --split-input-file --verify-diagnostics
// RUN: circt-opt %s --split-input-file --verify-diagnostics --canonicalize

hw.module @fmt_infinite_concat() {
hw.module @fmt_infinite_concat_verify() {
%lp = sim.fmt.lit ", {"
%rp = sim.fmt.lit "}"
// expected-error @below {{op is infinitely recursive.}}
%ordinal = sim.fmt.concat (%ordinal, %lp, %ordinal, %rp)
}
// -----

// Indirect cycle between concats: %2 uses %4, %4 uses %3, and %3 uses %2.
// None of these ops takes its own result as a direct operand. %cat is not
// part of the cycle itself but reaches it through %4. Every concat below
// carries a warning annotation for the cyclic-concatenation diagnostic.
hw.module @fmt_infinite_concat_canonicalize(in %val : i8, out res: !sim.fstring) {
%c = sim.fmt.char %val : i8
%0 = sim.fmt.lit "Here we go round the"
%1 = sim.fmt.lit "prickly pear"
// expected-warning @below {{Cyclic concatenation detected.}}
%2 = sim.fmt.concat (%1, %c, %4)
// expected-warning @below {{Cyclic concatenation detected.}}
%3 = sim.fmt.concat (%1, %c, %1, %c, %2, %c)
// expected-warning @below {{Cyclic concatenation detected.}}
%4 = sim.fmt.concat (%0, %c, %3)
%5 = sim.fmt.lit "At five o'clock in the morning."
// expected-warning @below {{Cyclic concatenation detected.}}
%cat = sim.fmt.concat (%4, %c, %5)
hw.output %cat : !sim.fstring
}
Loading