Skip to content

Commit

Permalink
[RTG][Elaboration] Elaboration support for Bags (#7892)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Dec 9, 2024
1 parent e6e5dca commit de64a10
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 2 deletions.
164 changes: 162 additions & 2 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace rtg {
using namespace mlir;
using namespace circt;
using namespace circt::rtg;
using llvm::MapVector;

#define DEBUG_TYPE "rtg-elaboration"

Expand Down Expand Up @@ -83,7 +84,7 @@ namespace {
/// The abstract base class for elaborated values.
struct ElaboratorValue {
public:
enum class ValueKind { Attribute, Set };
enum class ValueKind { Attribute, Set, Bag };

ElaboratorValue(ValueKind kind) : kind(kind) {}
virtual ~ElaboratorValue() {}
Expand Down Expand Up @@ -197,6 +198,60 @@ class SetValue : public ElaboratorValue {
};
} // namespace

/// Holds an evaluated value of a `BagType`'d value.
class BagValue : public ElaboratorValue {
public:
BagValue(MapVector<ElaboratorValue *, uint64_t> &&bag, Type type)
: ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type),
cachedHash(llvm::hash_combine(
llvm::hash_combine_range(bag.begin(), bag.end()), type)) {}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return val->getKind() == ValueKind::Bag;
}

llvm::hash_code getHashValue() const override { return cachedHash; }

bool isEqual(const ElaboratorValue &other) const override {
auto *otherBag = dyn_cast<BagValue>(&other);
if (!otherBag)
return false;

if (cachedHash != otherBag->cachedHash)
return false;

return llvm::equal(bag, otherBag->bag) && type == otherBag->type;
}

#ifndef NDEBUG
void print(llvm::raw_ostream &os) const override {
os << "<bag {";
llvm::interleaveComma(bag, os,
[&](std::pair<ElaboratorValue *, uint64_t> el) {
el.first->print(os);
os << " -> " << el.second;
});
os << "} at " << this << ">";
}
#endif

const MapVector<ElaboratorValue *, uint64_t> &getBag() const { return bag; }

Type getType() const { return type; }

private:
// Stores the elaborated values of the bag.
const MapVector<ElaboratorValue *, uint64_t> bag;

// Store the type of the bag such that we can materialize this evaluated value
// also in the case where the bag is empty.
const Type type;

// Compute the hash only once at constructor time.
const llvm::hash_code cachedHash;
};

#ifndef NDEBUG
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ElaboratorValue &value) {
Expand Down Expand Up @@ -255,7 +310,7 @@ class Materializer {
OpBuilder builder = builderIter->second;

return TypeSwitch<ElaboratorValue *, Value>(val)
.Case<AttributeValue, SetValue>(
.Case<AttributeValue, SetValue, BagValue>(
[&](auto val) { return visit(val, builder, loc, emitError); })
.Default([](auto val) {
assert(false && "all cases must be covered above");
Expand Down Expand Up @@ -313,13 +368,45 @@ class Materializer {
return res;
}

Value visit(BagValue *val, OpBuilder &builder, Location loc,
function_ref<InFlightDiagnostic()> emitError) {
SmallVector<Value> values, weights;
values.reserve(val->getBag().size());
weights.reserve(val->getBag().size());
for (auto [val, weight] : val->getBag()) {
auto materializedVal =
materialize(val, builder.getBlock(), loc, emitError);
if (!materializedVal)
return Value();

auto iter = integerValues.find({weight, builder.getBlock()});
Value materializedWeight;
if (iter != integerValues.end()) {
materializedWeight = iter->second;
} else {
materializedWeight = builder.create<arith::ConstantOp>(
loc, builder.getIndexAttr(weight));
integerValues[{weight, builder.getBlock()}] = materializedWeight;
}

values.push_back(materializedVal);
weights.push_back(materializedWeight);
}

auto res =
builder.create<BagCreateOp>(loc, val->getType(), values, weights);
materializedValues[{val, builder.getBlock()}] = res;
return res;
}

private:
/// Cache values we have already materialized to reuse them later. We start
/// with an insertion point at the start of the block and cache the (updated)
/// insertion point such that future materializations can also reuse previous
/// materializations without running into dominance issues (or requiring
/// additional checks to avoid them).
DenseMap<std::pair<ElaboratorValue *, Block *>, Value> materializedValues;
DenseMap<std::pair<uint64_t, Block *>, Value> integerValues;

/// Cache the builders to continue insertions at their current insertion point
/// for the reason stated above.
Expand Down Expand Up @@ -427,6 +514,79 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagCreateOp op, function_ref<void(Operation *)> addToWorklist) {
MapVector<ElaboratorValue *, uint64_t> bag;
for (auto [val, multiple] :
llvm::zip(op.getElements(), op.getMultiples())) {
auto *interpValue = state.at(val);
// If the multiple is not stored as an AttributeValue, the elaboration
// must have already failed earlier (since we don't have
// unevaluated/opaque values).
auto *interpMultiple = cast<AttributeValue>(state.at(multiple));
uint64_t m = cast<IntegerAttr>(interpMultiple->getAttr()).getInt();
bag[interpValue] += m;
}

internalizeResult<BagValue>(op.getBag(), std::move(bag), op.getType());
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagSelectRandomOp op, function_ref<void(Operation *)> addToWorklist) {
auto *bag = cast<BagValue>(state.at(op.getBag()));

SmallVector<std::pair<ElaboratorValue *, uint32_t>> prefixSum;
prefixSum.reserve(bag->getBag().size());
uint32_t accumulator = 0;
for (auto [val, weight] : bag->getBag()) {
accumulator += weight;
prefixSum.push_back({val, accumulator});
}

auto customRng = rng;
if (auto intAttr =
op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
customRng = std::mt19937(intAttr.getInt());
}

auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
auto *iter = llvm::upper_bound(
prefixSum, idx,
[](uint32_t a, const std::pair<ElaboratorValue *, uint32_t> &b) {
return a < b.second;
});
state[op.getResult()] = iter->first;
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagDifferenceOp op, function_ref<void(Operation *)> addToWorklist) {
auto *original = cast<BagValue>(state.at(op.getOriginal()));
auto *diff = cast<BagValue>(state.at(op.getDiff()));

MapVector<ElaboratorValue *, uint64_t> result;
for (const auto &el : original->getBag()) {
if (!diff->getBag().contains(el.first)) {
result.insert(el);
continue;
}

if (op.getInf())
continue;

auto toDiff = diff->getBag().lookup(el.first);
if (el.second <= toDiff)
continue;

result.insert({el.first, el.second - toDiff});
}

internalizeResult<BagValue>(op.getResult(), std::move(result),
op.getType());
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
dispatchOpVisitor(Operation *op,
function_ref<void(Operation *)> addToWorklist) {
Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
func.func @dummy1(%arg0: i32, %arg1: i32, %arg2: !rtg.set<i32>) -> () {return}
func.func @dummy2(%arg0: i32) -> () {return}
func.func @dummy3(%arg0: i64) -> () {return}
func.func @dummy4(%arg0: i32, %arg1: i32, %arg2: !rtg.bag<i32>, %arg3: !rtg.bag<i32>) -> () {return}

// Test the set operations and passing a sequence to another one via argument
// CHECK-LABEL: rtg.test @setOperations
Expand All @@ -25,6 +26,29 @@ rtg.test @setOperations : !rtg.dict<> {
func.call @dummy1(%4, %5, %diff) : (i32, i32, !rtg.set<i32>) -> ()
}

// CHECK-LABEL: rtg.test @bagOperations
rtg.test @bagOperations : !rtg.dict<> {
// CHECK-NEXT: [[V0:%.+]] = arith.constant 2 : i32
// CHECK-NEXT: [[V1:%.+]] = arith.constant 8 : index
// CHECK-NEXT: [[V2:%.+]] = arith.constant 3 : i32
// CHECK-NEXT: [[V3:%.+]] = arith.constant 7 : index
// CHECK-NEXT: [[V4:%.+]] = rtg.bag_create ([[V1]] x [[V0]], [[V3]] x [[V2]]) : i32
// CHECK-NEXT: [[V5:%.+]] = rtg.bag_create ([[V1]] x [[V0]]) : i32
// CHECK-NEXT: func.call @dummy4([[V0]], [[V0]], [[V4]], [[V5]]) :
%multiple = arith.constant 8 : index
%one = arith.constant 1 : index
%0 = arith.constant 2 : i32
%1 = arith.constant 3 : i32
%bag = rtg.bag_create (%multiple x %0, %multiple x %1) : i32
%2 = rtg.bag_select_random %bag : !rtg.bag<i32> {rtg.elaboration_custom_seed = 3}
%new_bag = rtg.bag_create (%one x %2) : i32
%diff = rtg.bag_difference %bag, %new_bag : !rtg.bag<i32>
%3 = rtg.bag_select_random %diff : !rtg.bag<i32> {rtg.elaboration_custom_seed = 4}
%diff2 = rtg.bag_difference %bag, %new_bag inf : !rtg.bag<i32>
%4 = rtg.bag_select_random %diff2 : !rtg.bag<i32> {rtg.elaboration_custom_seed = 5}
func.call @dummy4(%3, %4, %diff, %diff2) : (i32, i32, !rtg.bag<i32>, !rtg.bag<i32>) -> ()
}

// CHECK-LABEL: @targetTest_target0
// CHECK: [[V0:%.+]] = arith.constant 0
// CHECK: func.call @dummy2([[V0]]) :
Expand Down

0 comments on commit de64a10

Please sign in to comment.