From de64a1051d0c55184e16847c6633eeab26822f49 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Mon, 9 Dec 2024 13:47:32 +0000 Subject: [PATCH] [RTG][Elaboration] Elaboration support for Bags (#7892) --- .../RTG/Transforms/ElaborationPass.cpp | 164 +++++++++++++++++- test/Dialect/RTG/Transform/elaboration.mlir | 24 +++ 2 files changed, 186 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 16a30d59cacf..9227356728f8 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -33,6 +33,7 @@ namespace rtg { using namespace mlir; using namespace circt; using namespace circt::rtg; +using llvm::MapVector; #define DEBUG_TYPE "rtg-elaboration" @@ -83,7 +84,7 @@ namespace { /// The abstract base class for elaborated values. struct ElaboratorValue { public: - enum class ValueKind { Attribute, Set }; + enum class ValueKind { Attribute, Set, Bag }; ElaboratorValue(ValueKind kind) : kind(kind) {} virtual ~ElaboratorValue() {} @@ -197,6 +198,60 @@ class SetValue : public ElaboratorValue { }; } // namespace +/// Holds an evaluated value of a `BagType`'d value. +class BagValue : public ElaboratorValue { +public: + BagValue(MapVector &&bag, Type type) + : ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type), + cachedHash(llvm::hash_combine( + llvm::hash_combine_range(bag.begin(), bag.end()), type)) {} + + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue *val) { + return val->getKind() == ValueKind::Bag; + } + + llvm::hash_code getHashValue() const override { return cachedHash; } + + bool isEqual(const ElaboratorValue &other) const override { + auto *otherBag = dyn_cast(&other); + if (!otherBag) + return false; + + if (cachedHash != otherBag->cachedHash) + return false; + + return llvm::equal(bag, otherBag->bag) && type == otherBag->type; + } + +#ifndef NDEBUG + void print(llvm::raw_ostream &os) const override { + os << " el) { + el.first->print(os); + os << " -> " << el.second; + }); + os << "} at " << this << ">"; + } +#endif + + const MapVector &getBag() const { return bag; } + + Type getType() const { return type; } + +private: + // Stores the elaborated values of the bag. + const MapVector bag; + + // Store the type of the bag such that we can materialize this evaluated value + // also in the case where the bag is empty. + const Type type; + + // Compute the hash only once at constructor time. + const llvm::hash_code cachedHash; +}; + #ifndef NDEBUG static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ElaboratorValue &value) { @@ -255,7 +310,7 @@ class Materializer { OpBuilder builder = builderIter->second; return TypeSwitch(val) - .Case( + .Case( [&](auto val) { return visit(val, builder, loc, emitError); }) .Default([](auto val) { assert(false && "all cases must be covered above"); @@ -313,6 +368,37 @@ class Materializer { return res; } + Value visit(BagValue *val, OpBuilder &builder, Location loc, + function_ref emitError) { + SmallVector values, weights; + values.reserve(val->getBag().size()); + weights.reserve(val->getBag().size()); + for (auto [val, weight] : val->getBag()) { + auto materializedVal = + materialize(val, builder.getBlock(), loc, emitError); + if (!materializedVal) + return Value(); + + auto iter = integerValues.find({weight, builder.getBlock()}); + Value materializedWeight; + if (iter != integerValues.end()) { + materializedWeight = iter->second; + } else { + materializedWeight = builder.create( + loc, builder.getIndexAttr(weight)); + integerValues[{weight, builder.getBlock()}] = materializedWeight; + } + + values.push_back(materializedVal); + weights.push_back(materializedWeight); + } + + auto res = + builder.create(loc, val->getType(), values, weights); + materializedValues[{val, builder.getBlock()}] = res; + return res; + } + private: /// Cache values we have already materialized to reuse them later. We start /// with an insertion point at the start of the block and cache the (updated) @@ -320,6 +406,7 @@ class Materializer { /// materializations without running into dominance issues (or requiring /// additional checks to avoid them). DenseMap, Value> materializedValues; + DenseMap, Value> integerValues; /// Cache the builders to continue insertions at their current insertion point /// for the reason stated above. @@ -427,6 +514,79 @@ class Elaborator : public RTGOpVisitor, return DeletionKind::Delete; } + FailureOr + visitOp(BagCreateOp op, function_ref addToWorklist) { + MapVector bag; + for (auto [val, multiple] : + llvm::zip(op.getElements(), op.getMultiples())) { + auto *interpValue = state.at(val); + // If the multiple is not stored as an AttributeValue, the elaboration + // must have already failed earlier (since we don't have + // unevaluated/opaque values). + auto *interpMultiple = cast(state.at(multiple)); + uint64_t m = cast(interpMultiple->getAttr()).getInt(); + bag[interpValue] += m; + } + + internalizeResult(op.getBag(), std::move(bag), op.getType()); + return DeletionKind::Delete; + } + + FailureOr + visitOp(BagSelectRandomOp op, function_ref addToWorklist) { + auto *bag = cast(state.at(op.getBag())); + + SmallVector> prefixSum; + prefixSum.reserve(bag->getBag().size()); + uint32_t accumulator = 0; + for (auto [val, weight] : bag->getBag()) { + accumulator += weight; + prefixSum.push_back({val, accumulator}); + } + + auto customRng = rng; + if (auto intAttr = + op->getAttrOfType("rtg.elaboration_custom_seed")) { + customRng = std::mt19937(intAttr.getInt()); + } + + auto idx = getUniformlyInRange(customRng, 0, accumulator - 1); + auto *iter = llvm::upper_bound( + prefixSum, idx, + [](uint32_t a, const std::pair &b) { + return a < b.second; + }); + state[op.getResult()] = iter->first; + return DeletionKind::Delete; + } + + FailureOr + visitOp(BagDifferenceOp op, function_ref addToWorklist) { + auto *original = cast(state.at(op.getOriginal())); + auto *diff = cast(state.at(op.getDiff())); + + MapVector result; + for (const auto &el : original->getBag()) { + if (!diff->getBag().contains(el.first)) { + result.insert(el); + continue; + } + + if (op.getInf()) + continue; + + auto toDiff = diff->getBag().lookup(el.first); + if (el.second <= toDiff) + continue; + + result.insert({el.first, el.second - toDiff}); + } + + internalizeResult(op.getResult(), std::move(result), + op.getType()); + return DeletionKind::Delete; + } + FailureOr dispatchOpVisitor(Operation *op, function_ref addToWorklist) { diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index 69000aa3714c..34b12d641357 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -3,6 +3,7 @@ func.func @dummy1(%arg0: i32, %arg1: i32, %arg2: !rtg.set) -> () {return} func.func @dummy2(%arg0: i32) -> () {return} func.func @dummy3(%arg0: i64) -> () {return} +func.func @dummy4(%arg0: i32, %arg1: i32, %arg2: !rtg.bag, %arg3: !rtg.bag) -> () {return} // Test the set operations and passing a sequence to another one via argument // CHECK-LABEL: rtg.test @setOperations @@ -25,6 +26,29 @@ rtg.test @setOperations : !rtg.dict<> { func.call @dummy1(%4, %5, %diff) : (i32, i32, !rtg.set) -> () } +// CHECK-LABEL: rtg.test @bagOperations +rtg.test @bagOperations : !rtg.dict<> { + // CHECK-NEXT: [[V0:%.+]] = arith.constant 2 : i32 + // CHECK-NEXT: [[V1:%.+]] = arith.constant 8 : index + // CHECK-NEXT: [[V2:%.+]] = arith.constant 3 : i32 + // CHECK-NEXT: [[V3:%.+]] = arith.constant 7 : index + // CHECK-NEXT: [[V4:%.+]] = rtg.bag_create ([[V1]] x [[V0]], [[V3]] x [[V2]]) : i32 + // CHECK-NEXT: [[V5:%.+]] = rtg.bag_create ([[V1]] x [[V0]]) : i32 + // CHECK-NEXT: func.call @dummy4([[V0]], [[V0]], [[V4]], [[V5]]) : + %multiple = arith.constant 8 : index + %one = arith.constant 1 : index + %0 = arith.constant 2 : i32 + %1 = arith.constant 3 : i32 + %bag = rtg.bag_create (%multiple x %0, %multiple x %1) : i32 + %2 = rtg.bag_select_random %bag : !rtg.bag {rtg.elaboration_custom_seed = 3} + %new_bag = rtg.bag_create (%one x %2) : i32 + %diff = rtg.bag_difference %bag, %new_bag : !rtg.bag + %3 = rtg.bag_select_random %diff : !rtg.bag {rtg.elaboration_custom_seed = 4} + %diff2 = rtg.bag_difference %bag, %new_bag inf : !rtg.bag + %4 = rtg.bag_select_random %diff2 : !rtg.bag {rtg.elaboration_custom_seed = 5} + func.call @dummy4(%3, %4, %diff, %diff2) : (i32, i32, !rtg.bag, !rtg.bag) -> () +} + // CHECK-LABEL: @targetTest_target0 // CHECK: [[V0:%.+]] = arith.constant 0 // CHECK: func.call @dummy2([[V0]]) :