From 3f3ee9429ed0662345a19dbef2f5f2615035e730 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 5 Dec 2024 22:37:33 -0500 Subject: [PATCH 1/3] prototype --- enzyme/Enzyme/CApi.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index ca71867462f..d6d7b1c6628 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -340,6 +340,21 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, }; } + +/// This is the main entry point to register a custom derivative for language frontends. +/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. +/// The main reason is that these rules can handle non-default activity cases, e.g. +/// a function call where a pointer or a float scalar is marked as const. +/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. +/// +/// This Function will only handle using ReverseMode AD (either split or combined). +/// As a high-level example, assume we want to register a custom derivative for a vector resize function. +/// We pass the mangled name as first argument. +/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle. +/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI +/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils +/// should all be provided available in the frontend. The last three arguments ... +/// void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -363,6 +378,13 @@ void EnzymeRegisterCallHandler(char *Name, }; } +/// This is the main entry point to register a custom derivative for language frontends. +/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. 
+/// The main reason is that these rules can handle non-default activity cases, e.g. +/// a function call where a pointer or a float scalar is marked as const. +/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. +/// +/// This Function will only handle using ForwardMode AD. void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, From 2dd80bf807dccc3c752be304c326f02684949316 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 6 Dec 2024 00:52:39 -0500 Subject: [PATCH 2/3] adressing feedback --- enzyme/Enzyme/CApi.cpp | 27 ++++++--------------------- enzyme/Enzyme/GradientUtils.h | 7 +++++++ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index d6d7b1c6628..7192fb07862 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -341,20 +341,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, } -/// This is the main entry point to register a custom derivative for language frontends. -/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. -/// The main reason is that these rules can handle non-default activity cases, e.g. -/// a function call where a pointer or a float scalar is marked as const. -/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. -/// -/// This Function will only handle using ReverseMode AD (either split or combined). -/// As a high-level example, assume we want to register a custom derivative for a vector resize function. -/// We pass the mangled name as first argument. -/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle. 
-/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI -/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils -/// should all be provided available in the frontend. The last three arguments ... -/// +/// This Function will only handle ReverseMode AD (either split or combined). +/// As a high-level example, assume we want to register a custom derivative for `pow(x, y)`. +/// We pass the mangled name of pow as first argument. +/// The IRBuilder B, the shadow argument, and gutils should all be available in the frontend. +/// The last three arguments ... void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -378,13 +369,7 @@ void EnzymeRegisterCallHandler(char *Name, }; } -/// This is the main entry point to register a custom derivative for language frontends. -/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. -/// The main reason is that these rules can handle non-default activity cases, e.g. -/// a function call where a pointer or a float scalar is marked as const. -/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. -/// -/// This Function will only handle using ForwardMode AD. +/// This Function will only handle ForwardMode AD. 
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index d739babf42e..2f9f1390fb4 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -79,6 +79,13 @@ extern llvm::StringMap &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&, llvm::Value *&)>, From c7545e8fe02acfa98c297eaff4ab7e8bfd8ecef8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 7 Dec 2024 02:55:40 -0500 Subject: [PATCH 3/3] add fwd mode example --- enzyme/Enzyme/CApi.cpp | 10 ++--- enzyme/Enzyme/GradientUtils.h | 71 ++++++++++++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 7192fb07862..3856cf045fe 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -341,11 +341,8 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, } -/// This Function will only handle ReverseMode AD (either split or combined). -/// As a high-level example, assume we want to register a custom derivative for `pow(x, y)`. -/// We pass the mangled name of pow as first argument. -/// The IRBuilder B, the shadow argument, and gutils should all be available in the frontend. -/// The last three arguments ... +/// This is the entry point to register reverse-mode custom derivatives programmatically. +/// More detailed documentation is available in GradientUtils.h. void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -369,7 +366,8 @@ void EnzymeRegisterCallHandler(char *Name, }; } -/// This Function will only handle ForwardMode AD. +/// This is the entry point to register forward-mode custom derivatives programmatically.
+/// More detailed documentation is available in GradientUtils.h. void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 2f9f1390fb4..8e98aa3584f 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -80,12 +80,13 @@ extern llvm::StringMap &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&, llvm::Value *&)>, @@ -93,6 +94,66 @@ extern llvm::StringMap>> customCallHandlers; +/// The StringMap allows looking up a (forward-mode) custom rule based on the mangled name of the function. +/// The first argument is the IRBuilder, the third argument is the GradientUtils, both of which should already be +/// available in the frontend. The second argument is the CallInst; this should be for a function call which will +/// compute the forward-mode derivative, while taking into consideration which input arguments are active or const. +/// The function returns true if the custom rule was applied, and false otherwise (e.g. because the combination of +/// activities is not yet supported). The last two arguments are ...
+/// +/// Example: +/// define double @my_pow(double %x, double %y) { +/// %call = call double @llvm.pow(double %x, double %y) +/// ret double %call +/// } +/// +/// The custom rule for this function could be: +/// customCallHandlers["my_pow"] = [](llvm::IRBuilder<> &Builder, llvm::CallInst *CI, GradientUtils &gutils, llvm::Value *&dcall, llvm::Value *&normalReturn, llvm::Value *&shadowReturn) { +/// auto x = CI->getArgOperand(0); +/// auto y = CI->getArgOperand(1); +/// auto xprime = gutils.getNewFromOriginal(x); +/// auto yprime = gutils.getNewFromOriginal(y); +/// bool is_x_active = !gutils.isConstantValue(x); +/// bool is_y_active = !gutils.isConstantValue(y); +/// normalReturn = Builder.CreateCall(Intrinsic::pow, {x, y}); +/// if (is_x_active) { +/// auto ym1 = Builder.CreateFSub(y, ConstantFP::get(Type::getDoubleTy(CI->getContext()), 1.0)); +/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, ym1}); +/// auto ypow = Builder.CreateFMul(y, pow); +/// shadowReturn = Builder.CreateFMul(xprime, ypow); +/// +/// // if y were inactive, this would be conceptually equivalent to generating +/// // define internal double @fwddiffetester(double %x, double %"x'", double %y) #1 { +/// // %0 = fsub fast double %y, 1.000000e+00 +/// // %1 = call fast double @llvm.pow.f64(double %x, double %0) +/// // %2 = fmul fast double %y, %1 +/// // %3 = fmul fast double %"x'", %2 +/// // ret double %3 +/// // } +/// } +/// if (is_y_active) { +/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, y}); +/// auto log = Builder.CreateCall(Intrinsic::log, {x}); +/// auto logpow = Builder.CreateFMul(pow, log); +/// auto ylogpow = Builder.CreateFMul(yprime, logpow); +/// if (is_x_active) { +/// shadowReturn = Builder.CreateFAdd(ylogpow, shadowReturn); +/// } else { +/// shadowReturn = ylogpow; +/// } +/// +/// // if x were inactive, this would be conceptually equivalent to generating +/// // define internal double @fwddiffetester.1(double %x, double %y, double %"y'") #1 { +/// // %0 = call
fast double @llvm.pow.f64(double %x, double %y) +/// // %1 = call fast double @llvm.log.f64(double %x) +/// // %2 = fmul fast double %0, %1 +/// // %3 = fmul fast double %"y'", %2 +/// // ret double %3 +/// // } +/// } +/// // We covered all 2x2 combinations, so always return true +/// return true; +/// } extern llvm::StringMap< std::function &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&)>>