diff --git a/tetragono/tetragono/sampling_lattice/observer.py b/tetragono/tetragono/sampling_lattice/observer.py index 6efb18ab4..38da63429 100644 --- a/tetragono/tetragono/sampling_lattice/observer.py +++ b/tetragono/tetragono/sampling_lattice/observer.py @@ -637,8 +637,48 @@ def natural_gradient_by_conjugate_gradient(self, step, error): mpi_gpu_comm.Gatherv(Delta, (gpu_Delta, gpu_Ns_list * Np)) # 收集 if exec_color == 1: - x, reason, result_step, result_error = tetraux.cg(gpu_Ns_total, Np, gpu_Energy, gpu_Delta, step, error, - gpu_color, mpi_exec_comm) + # 重新分发矩阵 + + # comm基本信息 + exec_rank = mpi_exec_comm.Get_rank() + exec_size = mpi_exec_comm.Get_size() + # 所有进程都要知道每个进程的构型数目信息 + Ns_list = np.zeros([exec_size], dtype=np.int32) + mpi_exec_comm.Allgather(np.array(gpu_Ns_total, dtype=np.int32), Ns_list) + Ns_total = sum(Ns_list) + # 所有进程都知道每个进程的Energy + all_Energy = np.zeros([Ns_total], dtype=dtype) + mpi_exec_comm.Allgatherv(gpu_Energy, (all_Energy, Ns_list)) + # 切割Np + quotient = Np // exec_size + remainder = Np % exec_size + Np_list = np.zeros(exec_size, dtype=np.int32) + quotient + Np_list[:remainder] += 1 + # 每个进程先行切割自己的Delta + Delta_pre = np.zeros([Np * Ns_list[exec_rank]], dtype=dtype) + begin = 0 + end = 0 + Ns_local = Ns_list[exec_rank] + for i in range(exec_size): + end += Np_list[i] + np.copyto( + Delta_pre[begin * Ns_local:end * Ns_local].reshape([Ns_local, Np_list[i]]), + gpu_Delta[:, begin:end], + ) + begin = end + send_info = Ns_list[exec_rank] * Np_list + # 准备buff接受重新分发后的矩阵 + Delta_redistributed = np.zeros([Ns_total, Np_list[exec_rank]], dtype=dtype) + recv_info = Ns_list * Np_list[exec_rank] + # 接收 + mpi_exec_comm.Barrier() # 不加的话mpi会出错,似乎是mpi的一个bug + mpi_exec_comm.Alltoallv((Delta_pre, send_info), (Delta_redistributed, recv_info)) + + x_i, reason, result_step, result_error = tetraux.cg(Ns_total, Np_list[exec_rank], all_Energy, + Delta_redistributed, step, error, gpu_color, + mpi_exec_comm) + x = np.zeros([Np], dtype=dtype) + mpi_exec_comm.Allgatherv(x_i, (x, Np_list)) else: x = np.zeros([Np], dtype=dtype) reason = "" diff --git a/tetraux/tetraux.cpp b/tetraux/tetraux.cpp index 501021bc1..3f6c59086 100644 --- a/tetraux/tetraux.cpp +++ b/tetraux/tetraux.cpp @@ -160,7 +160,7 @@ auto cg( real_base new_r_square; real_base alpha; real_base beta; - std::unique_ptr res_h(new Scalar[Np]); + std::unique_ptr res_h(new Scalar[Ns]); // For complex scalar, conjugate on Delta is needed if constexpr (!std::is_same_v, Scalar>) { @@ -175,6 +175,11 @@ auto cg( // res = Delta @ v // Previous: N -> fix fortran order -> T -> manually conjugate -> C gemv(handle.get(), blas_op_c, Np, Ns, &f_one, Delta.get(), Np, v.get(), i_one, &f_zero, res.get(), i_one); + // allreduce res + // TODO: For rocm aware mpi, it is possible to allreduce directly inside gpu. + res.to_host(res_h.get()); + MPI_Allreduce(MPI_IN_PLACE, res_h.get(), Ns, mpi_datatype, MPI_SUM, comm); + res.from_host(res_h.get()); }; auto DT = [&](gpu_array& v, // Ns gpu_array& res // Np @@ -182,11 +187,6 @@ auto cg( // res = Delta.H @ v // Previous: C -> fix fortran order -> C without T -> manually conjugate -> N gemv(handle.get(), blas_op_n, Np, Ns, &f_one, Delta.get(), Np, v.get(), i_one, &f_zero, res.get(), i_one); - // allreduce res - // TODO: For rocm aware mpi, it is possible to allreduce directly inside gpu. - res.to_host(res_h.get()); - MPI_Allreduce(MPI_IN_PLACE, res_h.get(), Np, mpi_datatype, MPI_SUM, comm); - res.from_host(res_h.get()); }; // CG @@ -196,6 +196,7 @@ auto cg( // b_square = b.H @ b nrm2(handle.get(), Np, b.get(), i_one, &b_square); b_square = b_square * b_square; + MPI_Allreduce(MPI_IN_PLACE, &b_square, 1, mpi_datatype>, MPI_SUM, comm); // x = 0 scal(handle.get(), Np, &f_zero, x.get(), i_one); @@ -206,6 +207,7 @@ auto cg( // r_square = r.H @ r nrm2(handle.get(), Np, r.get(), i_one, &r_square); r_square = r_square * r_square; + MPI_Allreduce(MPI_IN_PLACE, &r_square, 1, mpi_datatype>, MPI_SUM, comm); int t = 0; const char* reason; @@ -224,7 +226,6 @@ auto cg( // alpha = r_square / allreduce(Dp.H @ Dp) nrm2(handle.get(), Ns, Dp.get(), i_one, &alpha); alpha = alpha * alpha; - MPI_Allreduce(MPI_IN_PLACE, &alpha, 1, mpi_datatype>, MPI_SUM, comm); alpha = r_square / alpha; // x = x + alpha * p Scalar c_alpha = alpha; @@ -236,6 +237,7 @@ auto cg( // new_r_square = r.H @ r nrm2(handle.get(), Np, r.get(), i_one, &new_r_square); new_r_square = new_r_square * new_r_square; + MPI_Allreduce(MPI_IN_PLACE, &new_r_square, 1, mpi_datatype>, MPI_SUM, comm); // beta = new_r_square / r_square beta = new_r_square / r_square; // r_square = new_r_square