From 095ee7e3f42931360c3771a6767eea24b2e5c4e2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 02:15:45 -0500 Subject: [PATCH] Support batching scalar types (#2175) (#2180) * support batching scalar types * formatting Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> --- enzyme/Enzyme/CallDerivatives.cpp | 4 ++++ enzyme/Enzyme/FunctionUtils.cpp | 3 ++- enzyme/Enzyme/GradientUtils.cpp | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 243077c1376..22df3dab9a7 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3312,6 +3312,10 @@ bool AdjointGenerator::handleKnownCallDerivatives( } #endif Value *replacement = B.CreateAlloca(elTy, Size); + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) + if (auto M = call.getMetadata(MD)) + cast(replacement)->setMetadata(MD, M); if (I) replacement->takeName(I); else diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3cd0203c08a..1eac88e5b54 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -508,7 +508,8 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode, {ConstantAsMetadata::get(ConstantInt::get( IntegerType::get(AI->getContext(), 64), align))})); - for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type"}) + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) if (auto M = AI->getMetadata(MD)) CI->setMetadata(MD, M); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 8dfd7ae8104..f5d5ffa6854 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3280,6 +3280,10 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto replacement = NB.CreateAlloca( Type::getInt8Ty(I.getContext()), lookupM(getNewFromOriginal(I.getOperand(0)), NB, available)); + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) + if (auto M = I.getMetadata(MD)) + replacement->setMetadata(MD, M); auto Alignment = cast( cast(MD->getOperand(0))->getValue()) @@ -3524,6 +3528,10 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto rule = [&](Value *anti) { AllocaInst *replacement = NB.CreateAlloca( Type::getInt8Ty(orig->getContext()), args[0]); + for (auto MD : {"enzyme_active", "enzyme_inactive", + "enzyme_type", "enzymejl_allocart"}) + if (auto M = I.getMetadata(MD)) + replacement->setMetadata(MD, M); replacement->takeName(anti); auto Alignment = cast(cast( MD->getOperand(0))