diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index ca71867462f..3856cf045fe 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -340,6 +340,9 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, }; } + +/// This is the entry point to register reverse-mode custom derivatives programmatically. +/// A more detailed documentation is available in GradientUtils.h void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -363,6 +366,8 @@ void EnzymeRegisterCallHandler(char *Name, }; } +/// This is the entry point to register forward-mode custom derivatives programmatically. +/// A more detailed documentation is available in GradientUtils.h void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index d739babf42e..8e98aa3584f 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -79,6 +79,14 @@ extern llvm::StringMap &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&, llvm::Value *&)>, @@ -86,6 +94,66 @@ extern llvm::StringMap>> customCallHandlers; +/// The StringMap allows looking up a (forward-mode) custom rule based on the mangled name of the function. +/// The first argument is the IRBuilder, the third argument are gradientutils, both of which should be already +/// available in the frontend. The second argument is the CallInst, this should be for a function call which will +/// compute the forward-mode derivative, while taking into consideration which input arguments are active or const. +/// The function returns true, if the custom rule was applied, and false otherwise (e.g. because the combination of +/// activities is not yet supported). The last two arguments are ... +/// +/// Example: +/// define double @my_pow(double %x, double %y) { +/// %call = call double @llvm.pow(double %x, double %y) +/// ret double %call +/// } +/// +/// The custom rule for this function could be: +/// customCallHandlers["my_pow"] = [](llvm::IRBuilder<> &Builder, llvm::CallInst *CI, GradientUtils &gutils, llvm::Value *&dcall, llvm::Value *&normalReturn, llvm::Value *&shadowReturn) { +/// auto x = CI->getArgOperand(0); +/// auto y = CI->getArgOperand(1); +/// auto xprime = gutils.getNewFromOriginal(x); +/// auto yprime = gutils.getNewFromOriginal(y); +/// bool is_x_active = !gutils.isConstantValue(x); +/// bool is_y_active = !gutils.isConstantValue(y); +/// normalreturn = Builder.CreateCall(Intrinsic::pow, {x, y}); +/// if (is_x_active) { +/// auto ym1 = Builder.CreateFSub(y, ConstantFP::get(Type::getDoubleTy(CI->getContext()), 1.0)); +/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, ym1}); +/// auto ypow = Builder.CreateFMul(y, pow); +/// shadowReturn = Builder.CreateFMul(xprime, ypow); +/// +/// // if y were inactive, this would be conceptually equivalent to generating +/// // define internal double @fwddiffetester(double %x, double %"x'", double %y) #1 { +/// // %0 = fsub fast double %y, 1.000000e+00 +/// // %1 = call fast double @llvm.pow.f64(double %x, double %0) +/// // %2 = fmul fast double %y, %1 +/// // %3 = fmul fast double %"x'", %2 +/// // ret double %3 +/// // } +/// } +/// if (is_y_active) { +/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, y}); +/// auto log = Builder.CreateCall(Intrinsic::log, {x}); +/// auto logpow = Builder.CreateFMul(pow, log); +/// auto ylogpow = Builder.CreateFMul(yprime, logpow); +/// if (is_x_active) { +/// shadowReturn = Builder.CreateFAdd(ylogpow, shadowReturn); +/// } else { +/// shadowReturn = ylogpow; +/// } +/// +/// // if x was inactive, this would be conceptually equivalent to generating +/// // define internal double @fwddiffetester.1(double %x, double %y, double %"y'") #1 { +/// // %0 = call fast double @llvm.pow.f64(double %x, double %y) +/// // %1 = call fast double @llvm.log.f64(double %x) +/// // %2 = fmul fast double %0, %1 +/// // %3 = fmul fast double %"y'", %2 +/// // ret double %3 +/// // } +/// } +/// // We covered all 2x2 combinations, so always return true +/// return true; +/// } extern llvm::StringMap< std::function &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&)>>