Skip to content

Commit

Permalink
更改mpi分布Delta矩阵的方式
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 27, 2023
1 parent 80b133a commit 1ccd571
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
44 changes: 42 additions & 2 deletions tetragono/tetragono/sampling_lattice/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,48 @@ def natural_gradient_by_conjugate_gradient(self, step, error):
mpi_gpu_comm.Gatherv(Delta, (gpu_Delta, gpu_Ns_list * Np)) # 收集

if exec_color == 1:
x, reason, result_step, result_error = tetraux.cg(gpu_Ns_total, Np, gpu_Energy, gpu_Delta, step, error,
gpu_color, mpi_exec_comm)
# 重新分发矩阵

# comm基本信息
exec_rank = mpi_exec_comm.Get_rank()
exec_size = mpi_exec_comm.Get_size()
# 所有进程都要知道每个进程的构型数目信息
Ns_list = np.zeros([exec_size], dtype=np.int32)
mpi_exec_comm.Allgather(np.array(gpu_Ns_total, dtype=np.int32), Ns_list)
Ns_total = sum(Ns_list)
# 所有进程都知道每个进程的Energy
all_Energy = np.zeros([Ns_total], dtype=dtype)
mpi_exec_comm.Allgatherv(gpu_Energy, (all_Energy, Ns_list))
# 切割Np
quotient = Np // exec_size
remainder = Np % exec_size
Np_list = np.zeros(exec_size, dtype=np.int32) + quotient
Np_list[:remainder] += 1
# 每个进程先行切割自己的Delta
Delta_pre = np.zeros([Np * Ns_list[exec_rank]], dtype=dtype)
begin = 0
end = 0
Ns_local = Ns_list[exec_rank]
for i in range(exec_size):
end += Np_list[i]
np.copyto(
Delta_pre[begin * Ns_local:end * Ns_local].reshape([Ns_local, Np_list[i]]),
gpu_Delta[:, begin:end],
)
begin = end
send_info = Ns_list[exec_rank] * Np_list
# 准备buff接受重新分发后的矩阵
Delta_redistributed = np.zeros([Ns_total, Np_list[exec_rank]], dtype=dtype)
recv_info = Ns_list * Np_list[exec_rank]
# 接收
mpi_exec_comm.Barrier() # 不加的话mpi会出错,似乎是mpi的一个bug
mpi_exec_comm.Alltoallv((Delta_pre, send_info), (Delta_redistributed, recv_info))

x_i, reason, result_step, result_error = tetraux.cg(Ns_total, Np_list[exec_rank], all_Energy,
Delta_redistributed, step, error, gpu_color,
mpi_exec_comm)
x = np.zeros([Np], dtype=dtype)
mpi_exec_comm.Allgatherv(x_i, (x, Np_list))
else:
x = np.zeros([Np], dtype=dtype)
reason = ""
Expand Down
16 changes: 9 additions & 7 deletions tetraux/tetraux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ auto cg(
real_base<Scalar> new_r_square;
real_base<Scalar> alpha;
real_base<Scalar> beta;
std::unique_ptr<Scalar[]> res_h(new Scalar[Np]);
std::unique_ptr<Scalar[]> res_h(new Scalar[Ns]);

// For complex scalar, conjugate on Delta is needed
if constexpr (!std::is_same_v<real_base<Scalar>, Scalar>) {
Expand All @@ -175,18 +175,18 @@ auto cg(
// res = Delta @ v
// Previous: N -> fix fortran order -> T -> manually conjugate -> C
gemv<Scalar>(handle.get(), blas_op_c, Np, Ns, &f_one, Delta.get(), Np, v.get(), i_one, &f_zero, res.get(), i_one);
// allreduce res
// TODO: With ROCm-aware MPI, the allreduce could operate directly on the GPU buffer instead of staging through host memory.
res.to_host(res_h.get());
MPI_Allreduce(MPI_IN_PLACE, res_h.get(), Ns, mpi_datatype<Scalar>, MPI_SUM, comm);
res.from_host(res_h.get());
};
auto DT = [&](gpu_array<Scalar>& v, // Ns
gpu_array<Scalar>& res // Np
) {
// res = Delta.H @ v
// Previous: C -> fix fortran order -> C without T -> manually conjugate -> N
gemv<Scalar>(handle.get(), blas_op_n, Np, Ns, &f_one, Delta.get(), Np, v.get(), i_one, &f_zero, res.get(), i_one);
// allreduce res
// TODO: With ROCm-aware MPI, the allreduce could operate directly on the GPU buffer instead of staging through host memory.
res.to_host(res_h.get());
MPI_Allreduce(MPI_IN_PLACE, res_h.get(), Np, mpi_datatype<Scalar>, MPI_SUM, comm);
res.from_host(res_h.get());
};

// CG
Expand All @@ -196,6 +196,7 @@ auto cg(
// b_square = b.H @ b
nrm2<Scalar>(handle.get(), Np, b.get(), i_one, &b_square);
b_square = b_square * b_square;
MPI_Allreduce(MPI_IN_PLACE, &b_square, 1, mpi_datatype<real_base<Scalar>>, MPI_SUM, comm);

// x = 0
scal<Scalar>(handle.get(), Np, &f_zero, x.get(), i_one);
Expand All @@ -206,6 +207,7 @@ auto cg(
// r_square = r.H @ r
nrm2<Scalar>(handle.get(), Np, r.get(), i_one, &r_square);
r_square = r_square * r_square;
MPI_Allreduce(MPI_IN_PLACE, &r_square, 1, mpi_datatype<real_base<Scalar>>, MPI_SUM, comm);

int t = 0;
const char* reason;
Expand All @@ -224,7 +226,6 @@ auto cg(
// alpha = r_square / allreduce(Dp.H @ Dp)
nrm2<Scalar>(handle.get(), Ns, Dp.get(), i_one, &alpha);
alpha = alpha * alpha;
MPI_Allreduce(MPI_IN_PLACE, &alpha, 1, mpi_datatype<real_base<Scalar>>, MPI_SUM, comm);
alpha = r_square / alpha;
// x = x + alpha * p
Scalar c_alpha = alpha;
Expand All @@ -236,6 +237,7 @@ auto cg(
// new_r_square = r.H @ r
nrm2<Scalar>(handle.get(), Np, r.get(), i_one, &new_r_square);
new_r_square = new_r_square * new_r_square;
MPI_Allreduce(MPI_IN_PLACE, &new_r_square, 1, mpi_datatype<real_base<Scalar>>, MPI_SUM, comm);
// beta = new_r_square / r_square
beta = new_r_square / r_square;
// r_square = new_r_square
Expand Down

0 comments on commit 1ccd571

Please sign in to comment.