diff --git a/include/circt/Dialect/Sim/SimOps.td b/include/circt/Dialect/Sim/SimOps.td index ab9e1ab8ed10..a23fb1829f84 100644 --- a/include/circt/Dialect/Sim/SimOps.td +++ b/include/circt/Dialect/Sim/SimOps.td @@ -297,7 +297,34 @@ def FormatStringConcatOp : SimOp<"fmt.concat", [Pure]> { let hasVerifier = true; let assemblyFormat = "` ` `(` $inputs `)` attr-dict"; -} + let extraClassDeclaration = [{ + /// Returns true iff all of the input strings are primitive + /// (i.e. non-concatenated) tokens or block arguments. + bool isFlat() { + return llvm::none_of(getInputs(), [](Value operand) { + return !!operand.getDefiningOp(); + }); + }; + + /// Attempts to flatten this operation's input strings as much as possible. + /// + /// The flattened values are pushed into the passed vector. + /// If the concatenation is acyclic, the function will return 'success' + /// and all the flattened values are guaranteed to _not_ be the result of + /// a format string concatenation. + /// If a cycle is encountered, the function will return 'failure'. + /// On encountering a cycle, the result of the concat operation + /// forming the cycle is pushed into the list of flattened values + /// and flattening continues without recursing into the cycle. + LogicalResult getFlattenedInputs(llvm::SmallVectorImpl &flatOperands); + }]; + + let builders = [ + OpBuilder<(ins "mlir::ValueRange":$inputs), [{ + return build($_builder, $_state, circt::sim::FormatStringType::get($_builder.getContext()), inputs); + }]> + ]; +} #endif // CIRCT_DIALECT_SIM_SIMOPS_TD diff --git a/lib/Dialect/Sim/SimOps.cpp b/lib/Dialect/Sim/SimOps.cpp index e2426b2d94fb..50772ce98f70 100644 --- a/lib/Dialect/Sim/SimOps.cpp +++ b/lib/Dialect/Sim/SimOps.cpp @@ -12,6 +12,8 @@ #include "circt/Dialect/Sim/SimOps.h" #include "circt/Dialect/HW/ModuleImplementation.h" +#include "circt/Dialect/SV/SVOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -190,8 +192,12 @@ static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef lits) { OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) { if (getNumOperands() == 0) return StringAttr::get(getContext(), ""); - if (getNumOperands() == 1) + if (getNumOperands() == 1) { + // Don't fold to our own result to avoid an infinte loop. + if (getResult() == getOperand(0)) + return {}; return getOperand(0); + } // Fold if all operands are literals. SmallVector lits; @@ -204,6 +210,49 @@ OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) { return concatLiterals(getContext(), lits); } +LogicalResult FormatStringConcatOp::getFlattenedInputs( + llvm::SmallVectorImpl &flatOperands) { + llvm::SmallMapVector concatStack; + bool isCyclic = false; + + // Perform a DFS on this operation's concatenated operands, + // collect the leaf format string tokens. + concatStack.insert({*this, 0}); + while (!concatStack.empty()) { + auto &top = concatStack.back(); + auto currentConcat = top.first; + unsigned operandIndex = top.second; + + // Iterate over concatenated operands + while (operandIndex < currentConcat.getNumOperands()) { + auto currentOperand = currentConcat.getOperand(operandIndex); + + if (auto nextConcat = + currentOperand.getDefiningOp()) { + // Concat of a concat + if (!concatStack.contains(nextConcat)) { + // Save the next operand index to visit on the + // stack and put the new concat on top. + top.second = operandIndex + 1; + concatStack.insert({nextConcat, 0}); + break; + } + // Cyclic concatenation encountered. Don't recurse. + isCyclic = true; + } + + flatOperands.push_back(currentOperand); + operandIndex++; + } + + // Pop the concat off of the stack if we have visited all operands. + if (operandIndex >= currentConcat.getNumOperands()) + concatStack.pop_back(); + } + + return success(!isCyclic); +} + LogicalResult FormatStringConcatOp::verify() { if (llvm::any_of(getOperands(), [&](Value operand) { return operand == getResult(); })) @@ -213,11 +262,30 @@ LogicalResult FormatStringConcatOp::verify() { LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op, PatternRewriter &rewriter) { - if (op.getNumOperands() < 2) - return failure(); // Should be handled by the folder auto fmtStrType = FormatStringType::get(op.getContext()); + // Check if we can flatten concats of concats + bool hasBeenFlattened = false; + SmallVector flatOperands; + if (!op.isFlat()) { + // Get a new, flattened list of operands + flatOperands.reserve(op.getNumOperands() + 4); + auto isAcyclic = op.getFlattenedInputs(flatOperands); + + if (failed(isAcyclic)) { + // Infinite recursion, but we cannot fail compilation right here (can we?) + // so just emit a warning and bail out. + op.emitWarning("Cyclic concatenation detected."); + return failure(); + } + + hasBeenFlattened = true; + } + + if (!hasBeenFlattened && op.getNumOperands() < 2) + return failure(); // Should be handled by the folder + // Check if there are adjacent literals we can merge or empty literals to // remove SmallVector litSequence; @@ -225,7 +293,8 @@ LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op, newOperands.reserve(op.getNumOperands()); FormatLitOp prevLitOp; - for (auto operand : op.getOperands()) { + auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands(); + for (auto operand : oldOperands) { if (auto litOp = operand.getDefiningOp()) { if (!litOp.getLiteral().empty()) { prevLitOp = litOp; @@ -263,7 +332,7 @@ LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op, } } - if (newOperands.size() == op.getNumOperands()) + if (!hasBeenFlattened && newOperands.size() == op.getNumOperands()) return failure(); // Nothing changed if (newOperands.empty()) diff --git a/test/Dialect/Sim/format-strings.mlir b/test/Dialect/Sim/format-strings.mlir index b45b3aeedf5f..75fd8eb88072 100644 --- a/test/Dialect/Sim/format-strings.mlir +++ b/test/Dialect/Sim/format-strings.mlir @@ -117,3 +117,51 @@ hw.module @constant_fold3(in %zeroWitdh: i0, out res: !sim.fstring) { %cat = sim.fmt.concat (%foo, %clf, %ccr, %null, %foo, %null, %cext) hw.output %cat : !sim.fstring } + + +// CHECK-LABEL: hw.module @flatten_concat1 +// CHECK-DAG: %[[LH:.+]] = sim.fmt.lit "Hex: " +// CHECK-DAG: %[[LD:.+]] = sim.fmt.lit "Dec: " +// CHECK-DAG: %[[LB:.+]] = sim.fmt.lit "Bin: " +// CHECK-DAG: %[[FH:.+]] = sim.fmt.hex %val : i8 +// CHECK-DAG: %[[FD:.+]] = sim.fmt.dec %val : i8 +// CHECK-DAG: %[[FB:.+]] = sim.fmt.bin %val : i8 +// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[LB]], %[[FB]], %[[LD]], %[[FD]], %[[LH]], %[[FH]]) +// CHECK: hw.output %[[CAT]] : !sim.fstring + +hw.module @flatten_concat1(in %val : i8, out res: !sim.fstring) { + %binLit = sim.fmt.lit "Bin: " + %binVal = sim.fmt.bin %val : i8 + %binCat = sim.fmt.concat (%binLit, %binVal) + + %decLit = sim.fmt.lit "Dec: " + %decVal = sim.fmt.dec %val : i8 + %decCat = sim.fmt.concat (%decLit, %decVal, %nocat) + + %nocat = sim.fmt.concat () + + %hexLit = sim.fmt.lit "Hex: " + %hexVal = sim.fmt.hex %val : i8 + %hexCat = sim.fmt.concat (%hexLit, %hexVal) + + %catcat = sim.fmt.concat (%binCat, %nocat, %decCat, %nocat, %hexCat) + hw.output %catcat : !sim.fstring +} + +// CHECK-LABEL: hw.module @flatten_concat2 +// CHECK-DAG: %[[F:.+]] = sim.fmt.lit "Foo" +// CHECK-DAG: %[[FF:.+]] = sim.fmt.lit "FooFoo" +// CHECK-DAG: %[[CHR:.+]] = sim.fmt.char %val : i8 +// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[F]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[F]]) +// CHECK: hw.output %[[CAT]] : !sim.fstring + +hw.module @flatten_concat2(in %val : i8, out res: !sim.fstring) { + %foo = sim.fmt.lit "Foo" + %char = sim.fmt.char %val : i8 + + %c = sim.fmt.concat (%foo, %char, %foo) + %cc = sim.fmt.concat (%c, %c) + %ccccc = sim.fmt.concat (%cc, %c, %cc) + + hw.output %ccccc : !sim.fstring +} diff --git a/test/Dialect/Sim/sim-errors.mlir b/test/Dialect/Sim/sim-errors.mlir index e40c48fad65b..9707bf44ab1b 100644 --- a/test/Dialect/Sim/sim-errors.mlir +++ b/test/Dialect/Sim/sim-errors.mlir @@ -1,8 +1,25 @@ -// RUN: circt-opt %s --split-input-file --verify-diagnostics +// RUN: circt-opt %s --split-input-file --verify-diagnostics --canonicalize -hw.module @fmt_infinite_concat() { +hw.module @fmt_infinite_concat_verify() { %lp = sim.fmt.lit ", {" %rp = sim.fmt.lit "}" // expected-error @below {{op is infinitely recursive.}} %ordinal = sim.fmt.concat (%ordinal, %lp, %ordinal, %rp) } +// ----- + +hw.module @fmt_infinite_concat_canonicalize(in %val : i8, out res: !sim.fstring) { + %c = sim.fmt.char %val : i8 + %0 = sim.fmt.lit "Here we go round the" + %1 = sim.fmt.lit "prickly pear" + // expected-warning @below {{Cyclic concatenation detected.}} + %2 = sim.fmt.concat (%1, %c, %4) + // expected-warning @below {{Cyclic concatenation detected.}} + %3 = sim.fmt.concat (%1, %c, %1, %c, %2, %c) + // expected-warning @below {{Cyclic concatenation detected.}} + %4 = sim.fmt.concat (%0, %c, %3) + %5 = sim.fmt.lit "At five o'clock in the morning." + // expected-warning @below {{Cyclic concatenation detected.}} + %cat = sim.fmt.concat (%4, %c, %5) + hw.output %cat : !sim.fstring +}