Skip to content

Commit

Permalink
[ARC] Keep just one parameter if it's given multiple times (#7284)
Browse files Browse the repository at this point in the history
  • Loading branch information
elhewaty authored Jul 9, 2024
1 parent dd93f06 commit ff02e7a
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 19 deletions.
85 changes: 77 additions & 8 deletions lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
PatternRewriter &rewriter) const final;
};

struct KeepOneVecOp : public OpRewritePattern<VectorizeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(VectorizeOp op,
PatternRewriter &rewriter) const final;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -285,6 +291,17 @@ LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
return failure();
}

LogicalResult updateInputOperands(VectorizeOp &vecOp,
const SmallVector<Value> &newOperands) {
// Set the new inputOperandSegments value
unsigned groupSize = vecOp.getResults().size();
unsigned numOfGroups = newOperands.size() / groupSize;
SmallVector<int32_t> newAttr(numOfGroups, groupSize);
vecOp.setInputOperandSegments(newAttr);
vecOp.getOperation()->setOperands(ValueRange(newOperands));
return success();
}

//===----------------------------------------------------------------------===//
// Canonicalization pattern implementations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -643,12 +660,7 @@ MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
if (!canBeMerged)
return failure();

// Set the new inputOperandSegments value
unsigned groupSize = vecOp.getResults().size();
unsigned numOfGroups = newOperands.size() / groupSize;
SmallVector<int32_t> newAttr(numOfGroups, groupSize);
vecOp.setInputOperandSegments(newAttr);
vecOp.getOperation()->setOperands(ValueRange(newOperands));
(void)updateInputOperands(vecOp, newOperands);

// Erase dead VectorizeOps
for (auto deadOp : vecOpsToRemove)
Expand All @@ -657,6 +669,63 @@ MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
return success();
}

namespace llvm {
static unsigned hashValue(const SmallVector<Value> &inputs) {
unsigned hash = hash_value(inputs.size());
for (auto input : inputs)
hash = hash_combine(hash, input);
return hash;
}

template <>
struct DenseMapInfo<SmallVector<Value>> {
static inline SmallVector<Value> getEmptyKey() {
return SmallVector<Value>();
}

static inline SmallVector<Value> getTombstoneKey() {
return SmallVector<Value>();
}

static unsigned getHashValue(const SmallVector<Value> &inputs) {
return hashValue(inputs);
}

static bool isEqual(const SmallVector<Value> &lhs,
const SmallVector<Value> &rhs) {
return lhs == rhs;
}
};
} // namespace llvm

LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
PatternRewriter &rewriter) const {
BitVector argsToRemove(vecOp.getInputs().size(), false);
DenseMap<SmallVector<Value>, unsigned> inExists;
auto &currentBlock = vecOp.getBody().front();
SmallVector<Value> newOperands;
unsigned shuffledBy = 0;
bool changed = false;
for (auto [argIdx, inputVec] : llvm::enumerate(vecOp.getInputs())) {
auto input = SmallVector<Value>(inputVec.begin(), inputVec.end());
if (auto in = inExists.find(input); in != inExists.end()) {
argsToRemove.set(argIdx);
rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx - shuffledBy),
currentBlock.getArgument(in->second));
currentBlock.eraseArgument(argIdx - shuffledBy);
++shuffledBy;
changed = true;
continue;
}
inExists[input] = argIdx;
newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
}

if (!changed)
return failure();
return updateInputOperands(vecOp, newOperands);
}

//===----------------------------------------------------------------------===//
// ArcCanonicalizerPass implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -704,8 +773,8 @@ void ArcCanonicalizerPass::runOnOperation() {
dialect->getCanonicalizationPatterns(patterns);
for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
op.getCanonicalizationPatterns(patterns, &ctxt);
patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps>(
&getContext());
patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
KeepOneVecOp>(&getContext());

// Don't test for convergence since it is often not reached.
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
Expand Down
58 changes: 47 additions & 11 deletions test/Dialect/Arc/arc-canonicalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {

// CHECK-LABEL: hw.module @Test_2_in_1(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%o, %v, %q, %s), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT ^bb0(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT [[AND:%.+]] = comb.and %arg2, %arg3 : i8
// CHECK-NEXT [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[ADD]], [[AND]]) : (i8, i8) -> i8
Expand Down Expand Up @@ -398,18 +398,14 @@ in %clock: !seq.clock) {
}

// CHECK-LABEL: hw.module @More_Than_One_Use(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock) {
// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[ADD]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t), ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8):
// CHECK-NEXT: [[AND:%.+]] = comb.and %arg0, %arg1 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg2) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8
// CHECK-NEXT: [[RET:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], [[ADD]]) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[RET]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC1]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

Expand Down Expand Up @@ -475,3 +471,43 @@ in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

hw.module @Repeated_input(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %en: i1, in %clock: !seq.clock) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%b, %e, %h, %k) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.mul %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%4 = arc.state @FooMux(%en, %R, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}

// CHECK-LABEL: hw.module @Repeated_input(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %en : i1, in %clock : !seq.clock) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k) : (i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8):
// CHECK-NEXT: [[MUL:%.+]] = comb.mul %arg0, %arg0 : i8
// CHECK-NEXT: arc.vectorize.return [[MUL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

hw.module @Repeated_input_1(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %en: i1, in %clock: !seq.clock) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%b, %e, %h, %k): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%4 = arc.state @FooMux(%en, %R, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}

// CHECK-LABEL: hw.module @Repeated_input_1(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %en : i1, in %clock : !seq.clock) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k) : (i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8):
// CHECK-NEXT: [[FALSE:%.+]] = hw.constant false
// CHECK-NEXT: [[EXTRACT:%.+]] = comb.extract %arg0 from 0 : (i8) -> i7
// CHECK-NEXT: [[CONCAT:%.+]] = comb.concat [[EXTRACT]], [[FALSE]] : i7, i1
// CHECK-NEXT: arc.vectorize.return [[CONCAT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

0 comments on commit ff02e7a

Please sign in to comment.