diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index e69145815ee..1e828629cb9 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -111,6 +111,8 @@ static const StringSet<> InactiveGlobals = { "jl_small_typeof", "ompi_request_null", "ompi_mpi_double", + "RSMPI_DOUBLE", + "RSMPI_FLOAT", "ompi_mpi_comm_world", "__cxa_thread_atexit_impl", "stderr", diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 19a832dd8e8..1b9fe136b6c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -145,9 +145,10 @@ class AdjointGenerator : public llvm::InstVisitor { C = CE->getOperand(0); } if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") { return ConstantInt::get(intType, 8, false); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { return ConstantInt::get(intType, 4, false); } } diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index bf481c13b93..21fbe612daa 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4785,11 +4785,12 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { C = CE->getOperand(0); } if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") { buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { + } else if (name == "ompi_mpi_cxx_bool") { buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast(C)) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 1e87bf2a9a0..7f29078b1af 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2344,9 +2344,10 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, C = CE->getOperand(0); } if (auto GV = dyn_cast(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") { type = ConcreteType(Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { type = ConcreteType(Type::getFloatTy(C->getContext())); } }