diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 60503c2b0fe..bc5a095cfae 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -30,7 +30,7 @@ using namespace llvm; extern "C" { void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, - LLVMValueRef) = nullptr; + LLVMValueRef, uint8_t) = nullptr; } void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, @@ -3062,9 +3062,12 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call), - idx, wrap(prev)); + idx, wrap(prev), used); + } } } if (Mode == DerivativeMode::ReverseModeCombined || @@ -3249,9 +3252,12 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx, - wrap(prev)); + wrap(prev), used); + } } idx++; prev = CI;