Switch to Hydrogen for MPI calls except for Spectrum MPI sends
fiedorowicz1 committed Apr 16, 2024
1 parent 1db91a2 commit e189333
Showing 2 changed files with 17 additions and 12 deletions.
cmake/configure_files/lbann_config.hpp.in (1 change: 1 addition, 0 deletions)
@@ -67,6 +67,7 @@
 #cmakedefine LBANN_HAS_SHMEM
 #cmakedefine LBANN_HAS_LARGESCALE_NODE2VEC
 #cmakedefine LBANN_HAS_ONNX
+#cmakedefine LBANN_BUILT_WITH_SPECTRUM
 
 #cmakedefine LBANN_DETERMINISTIC
 
src/data_ingestion/readers/data_reader_python_dataset.cpp (28 changes: 16 additions, 12 deletions)
@@ -171,7 +171,6 @@ void python_dataset_reader::shuffle_responses(DataType* responses_ptr)
   // in a batch that can't be split evenly will be split evenly across the
   // first n ranks (or subsets of ranks in the distconv case).
 
-  MPI_Comm trainer_comm = m_comm->get_trainer_comm().GetMPIComm();
   uint64_t rank = m_comm->get_rank_in_trainer();
   uint64_t nprocs = m_comm->get_procs_per_trainer();
   uint64_t trainer_rank = m_comm->get_trainer_rank();
@@ -188,7 +187,7 @@ void python_dataset_reader::shuffle_responses(DataType* responses_ptr)
   uint64_t distconv_extra_samples =
     global_mb_size % (nprocs / num_io_partitions);
 
-  uint64_t send_rank, recv_rank, send_rank_count, recv_rank_count;
+  uint64_t send_rank, recv_rank, send_rank_count, recv_rank_count;
   send_rank = recv_rank = send_rank_count = recv_rank_count = 0;
   uint64_t send_rank_max_count =
     local_distconv_mb_size + (distconv_extra_samples > 0);
@@ -201,24 +200,29 @@ void python_dataset_reader::shuffle_responses(DataType* responses_ptr)
                     m_num_responses * sizeof(DataType));
       }
       else {
+#ifdef LBANN_BUILT_WITH_SPECTRUM
+        // Due to a potential bug in Spectrum MPI's send, we must use ssend to
+        // avoid hangs.
         EL_CHECK_MPI_CALL(
           MPI_Ssend(&responses_ptr[send_rank_count * m_num_responses],
                     m_num_responses * sizeof(DataType),
                     MPI_BYTE,
-                    recv_rank,
+                    m_comm->get_world_rank(trainer_rank, recv_rank),
                     0,
-                    trainer_comm));
+                    m_comm->get_world_comm().GetMPIComm()));
+#else
+        m_comm->send(&responses_ptr[send_rank_count * m_num_responses],
+                     m_num_responses,
+                     trainer_rank,
+                     recv_rank);
+#endif
       }
     }
     else if (rank == recv_rank) {
-      EL_CHECK_MPI_CALL(
-        MPI_Recv(&responses_ptr[send_rank_count * m_num_responses],
-                 m_num_responses * sizeof(DataType),
-                 MPI_BYTE,
-                 send_rank,
-                 0,
-                 trainer_comm,
-                 MPI_STATUS_IGNORE));
+      m_comm->recv(&responses_ptr[recv_rank_count * m_num_responses],
+                   m_num_responses,
+                   trainer_rank,
+                   send_rank);
     }
 
     send_rank_count += 1;
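Context for the #ifdef LBANN_BUILT_WITH_SPECTRUM branch: MPI_Ssend completes only after the matching receive has started, while MPI_Send may complete as soon as the send buffer can be reused (often via internal buffering), which is why a synchronous send can sidestep the hang mentioned in the in-code comment. The sketch below shows the same guarded-send pattern in a self-contained MPI program; it is not LBANN code, and USE_SYNCHRONOUS_SEND is a hypothetical stand-in for LBANN_BUILT_WITH_SPECTRUM. Run with at least two ranks (e.g. mpirun -n 2).

#include <mpi.h>
#include <cstdio>

// Minimal standalone illustration (not LBANN code) of switching between a
// synchronous and a standard send at compile time.
int main(int argc, char** argv)
{
  MPI_Init(&argc, &argv);
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  double payload[4] = {0.0, 1.0, 2.0, 3.0};
  if (rank == 0) {
#ifdef USE_SYNCHRONOUS_SEND  // hypothetical stand-in for LBANN_BUILT_WITH_SPECTRUM
    // Completes only once rank 1 has started the matching receive.
    MPI_Ssend(payload, 4, MPI_DOUBLE, 1, 0, MPI_COMM_WORLD);
#else
    // May complete before the receive is posted, possibly via buffering.
    MPI_Send(payload, 4, MPI_DOUBLE, 1, 0, MPI_COMM_WORLD);
#endif
  }
  else if (rank == 1) {
    MPI_Recv(payload, 4, MPI_DOUBLE, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    std::printf("rank 1 received %f ... %f\n", payload[0], payload[3]);
  }

  MPI_Finalize();
  return 0;
}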
