Skip to content

Commit

Permalink
experimental expr JIT evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Dec 9, 2024
1 parent 882cd35 commit 92ba931
Showing 1 changed file with 130 additions and 5 deletions.
135 changes: 130 additions & 5 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@

#include "llvm/Demangle/Demangle.h"

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <llvm/Support/JSON.h>

#include "llvm/Pass.h"

Expand Down Expand Up @@ -91,7 +96,10 @@ static cl::opt<int> HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden,
"candidate expressions."));
static cl::opt<std::string>
FPOptCachePath("fpopt-cache-path", cl::init(""), cl::Hidden,
cl::desc("Experimental: path to cache Herbie results"));
cl::desc("Path to cache Herbie results"));
static cl::opt<bool>
FPOptEnableJIT("fpopt-enable-jit", cl::init(true), cl::Hidden,
cl::desc("Experimental: Use JIT in candidate evaluation"));
static cl::opt<int>
HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden,
cl::desc("Number of input points Herbie uses to evaluate "
Expand Down Expand Up @@ -2929,11 +2937,116 @@ class ApplicableFPCC {
}
};

void JITExpr(
const std::string &expr,
std::unordered_map<Value *, std::shared_ptr<FPNode>> &valueToNodeMap,
std::unordered_map<std::string, Value *> &symbolToValueMap,
const FastMathFlags &FMF, ArrayRef<FPNode *> outputs,
const SmallMapVector<Value *, double, 4> &inputValues,
SmallVectorImpl<double> &results) {
using namespace llvm::orc;
// llvm::errs() << "JIT'ting " << expr << "\n";

SmallSet<std::string, 8> argStrSet;
getUniqueArgs(expr, argStrSet);

size_t NumInputs = argStrSet.size();
size_t NumOutputs = 1;

auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap);

auto TSCtx = std::make_unique<LLVMContext>();
LLVMContext &Ctx = *TSCtx;
std::unique_ptr<Module> UniqueM = std::make_unique<Module>("jit_module", Ctx);

Type *Int64Ty = Type::getInt64Ty(Ctx);
Type *DoubleTy = Type::getDoubleTy(Ctx);
Type *DoublePtrTy = Type::getDoublePtrTy(Ctx);

FunctionType *FT =
FunctionType::get(Type::getVoidTy(Ctx),
{Int64Ty, Int64Ty, DoublePtrTy, DoublePtrTy}, false);
Function *JitFunc = Function::Create(FT, Function::ExternalLinkage,
"tempExpr", UniqueM.get());

auto ArgIt = JitFunc->arg_begin();
Value *NInVal = &*ArgIt++;
NInVal->setName("numInputs");
Value *NOutVal = &*ArgIt++;
NOutVal->setName("numOutputs");
Value *InArr = &*ArgIt++;
InArr->setName("inputs");
Value *OutArr = &*ArgIt++;
OutArr->setName("outputs");

BasicBlock *entry = BasicBlock::Create(Ctx, "entry", JitFunc);
Instruction *ReturnInst = ReturnInst::Create(Ctx, entry);
IRBuilder<> builder(ReturnInst);
builder.setFastMathFlags(FMF);

std::vector<std::string> argNames(argStrSet.begin(), argStrSet.end());
std::unordered_map<std::string, unsigned> argIndexMap;
for (unsigned i = 0; i < argNames.size(); i++)
argIndexMap[argNames[i]] = i;

// Load input values from the input array
ValueToValueMapTy VMap;
for (auto &kv : symbolToValueMap) {
const std::string &sym = kv.first;
Value *origVal = kv.second;
if (argIndexMap.count(sym)) {
Value *Index = builder.getInt64(argIndexMap[sym]);
Value *Ptr = builder.CreateGEP(DoubleTy, InArr, Index, sym + "_ptr");
Value *Loaded = builder.CreateLoad(DoubleTy, Ptr, sym);
VMap[origVal] = Loaded;
}
}

// Materialize the expression
Value *Expr = parsedNode->getLLValue(builder, &VMap);

// Store the result in the output array
Value *OutPtr =
builder.CreateGEP(DoubleTy, OutArr, {builder.getInt64(0)}, "out_ptr");
builder.CreateStore(Expr, OutPtr);

// JitFunc->print(llvm::errs());

llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto J = cantFail(LLJITBuilder().create());
ThreadSafeModule TSM(std::move(UniqueM), std::move(TSCtx));
cantFail(J->addIRModule(std::move(TSM)));

auto Sym = cantFail(J->lookup("tempExpr"));

using JitFuncTy = void (*)(int64_t, int64_t, const double *, double *);
auto *JitFuncPtr = Sym.toPtr<JitFuncTy>();

std::vector<double> inputVals(NumInputs, 0.0);
for (unsigned i = 0; i < argNames.size(); i++) {
Value *argVal = symbolToValueMap[argNames[i]];
auto it = inputValues.find(argVal);
assert(it != inputValues.end() &&
"Missing input value for a required argument!");
inputVals[i] = it->second;
}

std::vector<double> outputVals(NumOutputs, 0.0);

JitFuncPtr((int64_t)NumInputs, (int64_t)NumOutputs, inputVals.data(),
outputVals.data());

results.clear();
results.append(outputVals.begin(), outputVals.end());
}

void setUnifiedAccuracyCost(
ApplicableOutput &AO,
std::unordered_map<Value *, std::shared_ptr<FPNode>> &valueToNodeMap,
std::unordered_map<std::string, Value *> &symbolToValueMap) {

SmallVector<SmallMapVector<Value *, double, 4>, 4> sampledPoints;
getSampledPoints(AO.component->inputs.getArrayRef(), valueToNodeMap,
symbolToValueMap, sampledPoints);
Expand All @@ -2951,7 +3064,13 @@ void setUnifiedAccuracyCost(
// llvm::errs() << "DEBUG AO gold value: " << goldVal << "\n";
goldVals[pair.index()] = goldVal;

getFPValues(outputs, pair.value(), results);
if (FPOptEnableJIT) {
JITExpr(AO.expr, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags(), outputs,
pair.value(), results);
} else {
getFPValues(outputs, pair.value(), results);
}
double realVal = results[0];
// llvm::errs() << "DEBUG AO real value: " << realVal << "\n";

Expand Down Expand Up @@ -2993,7 +3112,13 @@ void setUnifiedAccuracyCost(

ArrayRef<FPNode *> outputs = {parsedNode.get()};
SmallVector<double, 1> results;
getFPValues(outputs, pair.value(), results);
if (FPOptEnableJIT) {
JITExpr(expr, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags(), outputs,
pair.value(), results);
} else {
getFPValues(outputs, pair.value(), results);
}
double realVal = results[0];

// llvm::errs() << "Real value: " << realVal << "\n";
Expand Down

0 comments on commit 92ba931

Please sign in to comment.