From c1b6729d4767cc42b6fe57d7ce2cf8c8c010d41c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 11 Aug 2024 15:01:55 -0400 Subject: [PATCH] JLInstSimplify multi arg --- enzyme/Enzyme/JLInstSimplify.cpp | 113 ++++++++++++++++++++--- enzyme/test/Enzyme/JLSimplify/yesptr2.ll | 26 ++++++ 2 files changed, 125 insertions(+), 14 deletions(-) create mode 100644 enzyme/test/Enzyme/JLSimplify/yesptr2.ll diff --git a/enzyme/Enzyme/JLInstSimplify.cpp b/enzyme/Enzyme/JLInstSimplify.cpp index c04132498790..e8420e251457 100644 --- a/enzyme/Enzyme/JLInstSimplify.cpp +++ b/enzyme/Enzyme/JLInstSimplify.cpp @@ -129,6 +129,58 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst) { return true; } +static inline SetVector getBaseObjects(llvm::Value *V, + bool offsetAllowed) { + SetVector results; + + SmallPtrSet seen; + SmallVector todo = {V}; + + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (seen.count(cur)) + continue; + seen.insert(cur); + auto obj = getBaseObject(cur, offsetAllowed); + if (auto PN = dyn_cast(obj)) { + for (auto &val : PN->incoming_values()) { + todo.push_back(val); + } + continue; + } + if (auto SI = dyn_cast(obj)) { + todo.push_back(SI->getTrueValue()); + todo.push_back(SI->getFalseValue()); + continue; + } + results.insert(obj); + } + return results; +} + +bool noaliased_or_arg(SetVector &lhs_v, + SetVector &rhs_v) { + for (auto lhs : lhs_v) { + auto lhs_na = isNoAlias(lhs); + auto lhs_arg = isa(lhs); + + // This LHS value is neither noalias or an argument + if (!lhs_na && !lhs_arg) + return false; + + for (auto rhs : rhs_v) { + if (lhs == rhs) + return false; + if (isNoAlias(lhs)) + continue; + if (!lhs_na && !isa(rhs)) + return false; + } + } + return true; +} + bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI) { bool changed = false; @@ -175,9 +227,9 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, } if (legal) { - auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false); - auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false); - if (lhs == rhs) { + auto lhs_v = getBaseObjects(I.getOperand(0), /*offsetAllowed*/ false); + auto rhs_v = getBaseObjects(I.getOperand(1), /*offsetAllowed*/ false); + if (lhs_v.size() == 1 && rhs_v.size() == 1 && lhs_v[0] == rhs_v[0]) { auto repval = ICmpInst::isTrueWhenEqual(pred) ? ConstantInt::get(I.getType(), 1) : ConstantInt::get(I.getType(), 0); @@ -185,8 +237,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, changed = true; continue; } - if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa(rhs))) || - (isNoAlias(rhs) && isa(lhs))) { + if (noaliased_or_arg(lhs_v, rhs_v)) { auto repval = ICmpInst::isTrueWhenEqual(pred) ? ConstantInt::get(I.getType(), 0) : ConstantInt::get(I.getType(), 1); @@ -194,14 +245,41 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, changed = true; continue; } - auto llhs = dyn_cast(lhs); - auto lrhs = dyn_cast(rhs); - if (llhs && lrhs && isa(llhs->getType()) && - isa(lrhs->getType())) { - auto lhsv = - getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false); - auto rhsv = - getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false); + bool loadlegal = true; + SmallVector llhs, lrhs; + for (auto lhs : lhs_v) { + auto ld = dyn_cast(lhs); + if (!ld || !isa(ld->getType())) { + loadlegal = false; + break; + } + llhs.push_back(ld); + } + for (auto rhs : rhs_v) { + auto ld = dyn_cast(rhs); + if (!ld || !isa(ld->getType())) { + loadlegal = false; + break; + } + lrhs.push_back(ld); + } + SetVector llhs_s, lrhs_s; + for (auto v : llhs) { + for (auto obj : + getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) { + llhs_s.insert(obj); + } + } + for (auto v : lrhs) { + for (auto obj : + getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) { + lrhs_s.insert(obj); + } + } + // TODO handle multi size + if (llhs_s.size() == 1 && lrhs_s.size() == 1 && loadlegal) { + auto lhsv = llhs_s[0]; + auto rhsv = lrhs_s[0]; if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa(rhsv) || notCapturedBefore(lhsv, &I))) || (isNoAlias(rhsv) && @@ -225,7 +303,14 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, if (!I->mayWriteToMemory()) return /*earlyBreak*/ false; - for (auto LI : {llhs, lrhs}) + for (auto LI : llhs) + if (writesToMemoryReadBy(AA, TLI, + /*maybeReader*/ LI, + /*maybeWriter*/ I)) { + overwritten = true; + return /*earlyBreak*/ true; + } + for (auto LI : lrhs) if (writesToMemoryReadBy(AA, TLI, /*maybeReader*/ LI, /*maybeWriter*/ I)) { diff --git a/enzyme/test/Enzyme/JLSimplify/yesptr2.ll b/enzyme/test/Enzyme/JLSimplify/yesptr2.ll new file mode 100644 index 000000000000..904985866fe4 --- /dev/null +++ b/enzyme/test/Enzyme/JLSimplify/yesptr2.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -jl-inst-simplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="jl-inst-simplify" -S | FileCheck %s + +declare i8** @malloc(i64) + +define fastcc i1 @augmented_julia__affine_normalize_1484(i1 %c) { + %i5 = call noalias i8** @malloc(i64 16) + br i1 %c, label %tval, label %fval + +tval: + %j29 = load i8*, i8** %i5, align 8 + br label %end + +fval: + %k29 = load i8*, i8** %i5, align 8 + br label %end + +end: + %i29 = phi i8* [ %j29, %tval ], [ %k29, %fval ] + %i31 = call noalias nonnull i8* addrspace(10)* inttoptr (i64 137352001798896 to i8* addrspace(10)* ({} addrspace(10)*, i64, i64)*)({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 137351863426640 to {}*) to {} addrspace(10)*), i64 10, i64 10) + %i35 = load i8*, i8* addrspace(10)* %i31, align 8 + %i39 = icmp ne i8* %i35, %i29 + ret i1 %i39 +} + +; CHECK: ret i1 true