Skip to content

Commit

Permalink
initial rust mpi support
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Jul 30, 2024
1 parent 3d97f17 commit efc8185
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
2 changes: 2 additions & 0 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ static const StringSet<> InactiveGlobals = {
"jl_small_typeof",
"ompi_request_null",
"ompi_mpi_double",
"RSMPI_DOUBLE",
"RSMPI_FLOAT",
"ompi_mpi_comm_world",
"__cxa_thread_atexit_impl",
"stderr",
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
return ConstantInt::get(intType, 8, false);
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
return ConstantInt::get(intType, 4, false);
}
}
Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4785,11 +4785,12 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
} else if (name == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,9 +2344,10 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
type = ConcreteType(Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
type = ConcreteType(Type::getFloatTy(C->getContext()));
}
}
Expand Down

0 comments on commit efc8185

Please sign in to comment.