Skip to content

Commit

Permalink
Refactor LUgpuCHandle_interface_impl.cu, cublas_cusolver_wrappers.hpp…
Browse files Browse the repository at this point in the history
…, l_panels_impl.hpp, luAuxStructTemplated.hpp, lupanelsComm3dGPU_impl.hpp, lupanels_GPU_impl.hpp, lupanels_comm3d_impl.hpp, schurCompUpdate_impl.cuh, superlu_blas.hpp, u_panels_impl.hpp, xgstrf2.hpp, xlupanels.hpp, pzgssvx3d.c, and trfAux.c

Summary: Refactor various files related to LU factorization and GPU handling.

Files Changed:
1. SRC/TRF3dV100/LUgpuCHandle_interface_impl.cu: Refactored code related to LU factorization and GPU handling.
2. SRC/TRF3dV100/cublas_cusolver_wrappers.hpp: Updated function signatures and added support for doublecomplex type.
3. SRC/TRF3dV100/l_panels_impl.hpp: Updated function signatures and added support for doublecomplex type.
4. SRC/TRF3dV100/luAuxStructTemplated.hpp: Added template functions and operators for complex types.
5. SRC/TRF3dV100/lupanelsComm3dGPU_impl.hpp: Updated function signatures and added support for doublecomplex type.
6. SRC/TRF3dV100/lupanels_GPU_impl.hpp: Updated function signatures and added support for doublecomplex type.
7. SRC/TRF3dV100/lupanels_comm3d_impl.hpp: Updated function signatures and added support for doublecomplex type.
8. SRC/TRF3dV100/schurCompUpdate_impl.cuh: Updated function signatures and added support for doublecomplex type.
9. SRC/TRF3dV100/superlu_blas.hpp: Updated function signatures and added support for doublecomplex type.
10. SRC/TRF3dV100/u_panels_impl.hpp: Updated function signatures and added support for doublecomplex type.
11. SRC/TRF3dV100/xgstrf2.hpp: Updated function signatures and added support for doublecomplex type.
12. SRC/TRF3dV100/xlupanels.hpp: Updated function signatures and added support for doublecomplex type.
13. SRC/complex16/pzgssvx3d.c: Refactored code related to LU factorization and GPU handling.
14. SRC/prec-independent/trfAux.c: Refactored code related to LU factorization and GPU handling.
  • Loading branch information
piyush314 committed Jan 13, 2024
1 parent 21df571 commit a105e52
Show file tree
Hide file tree
Showing 14 changed files with 385 additions and 55 deletions.
59 changes: 59 additions & 0 deletions SRC/TRF3dV100/LUgpuCHandle_interface_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,63 @@ extern "C"
return LU_v1->pdgstrf3d();

}


// Double complex precision:
/* Create the double-complex LU object and return it as an opaque handle.
 *
 * All arguments are forwarded unchanged to the xLUstruct_t<doublecomplex>
 * constructor. The object is heap-allocated with `new`; the matching
 * release is zDestroyLUgpuHandle(), which deletes it.
 *
 * When DEBUGlevel >= 1, CHECK_MALLOC is invoked on entry with grid3d->iam.
 */
zLUgpu_Handle zCreateLUgpuHandle(int_t nsupers, int_t ldt_,
                                 ztrf3Dpartition_t *trf3Dpartition,
                                 zLUstruct_t *LUstruct, gridinfo3d_t *grid3d,
                                 SCT_t *SCT_, superlu_dist_options_t *options_,
                                 SuperLUStat_t *stat,
                                 double thresh_, int *info_)
{
#if (DEBUGlevel >= 1)
    CHECK_MALLOC(grid3d->iam, "Enter createLUgpuHandle");
#endif

    using ZLUstruct = xLUstruct_t<doublecomplex>;

    // The C side only ever sees the opaque handle type.
    return reinterpret_cast<zLUgpu_Handle>(
        new ZLUstruct(nsupers, ldt_, trf3Dpartition,
                      LUstruct, grid3d,
                      SCT_, options_, stat,
                      thresh_, info_));
}

/* Destroy the double-complex LU object behind an opaque handle.
 *
 * The handle is cast back to xLUstruct_t<doublecomplex>* and deleted.
 * A progress line is printed (and stdout flushed) immediately before and
 * after the delete.
 */
void zDestroyLUgpuHandle(zLUgpu_Handle LuH)
{
    auto *luObject = reinterpret_cast<xLUstruct_t<doublecomplex> *>(LuH);

    printf("\t... before delete luH\n");
    fflush(stdout);

    delete luObject;

    printf("\t... after delete luH\n");
    fflush(stdout);
}

// I think the following is not used
int zGatherFactoredLU3Dto2D(zLUgpu_Handle LuH);

/* Bring the factored LU data back to the host and repack U.
 *
 * Steps, in order:
 *   1. If the object's superlu_acc_offload flag is set and the build has
 *      HAVE_CUDA, synchronize cuStreams[0] and call copyLUGPUtoHost().
 *   2. Call packedU2skyline(LUstruct).
 *   3. Print the elapsed wall time measured with SuperLU_timer_().
 *
 * Always returns 0.
 */
int zCopyLUGPU2Host(zLUgpu_Handle LuH, zLUstruct_t *LUstruct)
{
    auto *lu = reinterpret_cast<xLUstruct_t<doublecomplex> *>(LuH);

    const double tStart = SuperLU_timer_();

    if (lu->superlu_acc_offload)
    {
#ifdef HAVE_CUDA
        // Wait for stream 0 before copying; the original author noted this
        // synchronization may be redundant.
        cudaStreamSynchronize(lu->A_gpu.cuStreams[0]);
        lu->copyLUGPUtoHost();
#endif
    }

    lu->packedU2skyline(LUstruct);

    const double tElapsed = SuperLU_timer_() - tStart;
    printf("Time to send data back= %g\n", tElapsed);

    return 0;
}

/* C entry point that runs the factorization on a double-complex LU handle.
 *
 * Casts the opaque handle back to xLUstruct_t<doublecomplex>* and returns
 * whatever its pdgstrf3d() member returns.
 */
int pzgstrf3d_LUv1(zLUgpu_Handle LUHand)
{
    using ZLUstruct = xLUstruct_t<doublecomplex>;
    return reinterpret_cast<ZLUstruct *>(LUHand)->pdgstrf3d();
}
}
98 changes: 89 additions & 9 deletions SRC/TRF3dV100/cublas_cusolver_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,26 @@ template <typename Ftype>
cublasStatus_t myCublasTrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const Ftype *alpha, const Ftype *A, int lda, Ftype *B, int ldb);

// myCublasTrsm specialization for double.
// Forwards every argument unchanged to cublasDtrsm and returns its status;
// parameter meanings are those of the cuBLAS trsm routine.
template <>
cublasStatus_t myCublasTrsm<double>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
                                    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
                                    const double *alpha, const double *A, int lda, double *B, int ldb)
{
    return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}

// myCublasTrsm specialization for float.
// Forwards every argument unchanged to cublasStrsm and returns its status;
// parameter meanings are those of the cuBLAS trsm routine.
template <>
cublasStatus_t myCublasTrsm<float>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
                                   cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
                                   const float *alpha, const float *A, int lda, float *B, int ldb)
{
    return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}

// myCublasTrsm specialization for cuComplex.
// Forwards every argument unchanged to cublasCtrsm and returns its status;
// parameter meanings are those of the cuBLAS trsm routine.
template <>
cublasStatus_t myCublasTrsm<cuComplex>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
                                       cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
                                       const cuComplex *alpha, const cuComplex *A, int lda, cuComplex *B, int ldb)
{
    return cublasCtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}

// myCublasTrsm specialization for cuDoubleComplex.
// Forwards every argument unchanged to cublasZtrsm and returns its status;
// parameter meanings are those of the cuBLAS trsm routine.
template <>
cublasStatus_t myCublasTrsm<cuDoubleComplex>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
                                             cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
                                             const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda,
                                             cuDoubleComplex *B, int ldb)
{
    return cublasZtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}

Expand Down Expand Up @@ -109,21 +113,97 @@ template <typename Ftype>
cublasStatus_t myCublasGemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const Ftype *alpha, const Ftype *A, int lda, const Ftype *B, int ldb, const Ftype *beta, Ftype *C, int ldc);

// myCublasGemm specialization for double.
// Forwards every argument unchanged to cublasDgemm and returns its status;
// parameter meanings are those of the cuBLAS gemm routine.
template <>
cublasStatus_t myCublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                    int m, int n, int k, const double *alpha, const double *A, int lda,
                                    const double *B, int ldb, const double *beta, double *C, int ldc)
{
    return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

// myCublasGemm specialization for float.
// Forwards every argument unchanged to cublasSgemm and returns its status;
// parameter meanings are those of the cuBLAS gemm routine.
template <>
cublasStatus_t myCublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                   int m, int n, int k, const float *alpha, const float *A, int lda,
                                   const float *B, int ldb, const float *beta, float *C, int ldc)
{
    return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

// myCublasGemm specialization for cuComplex.
// Forwards every argument unchanged to cublasCgemm and returns its status;
// parameter meanings are those of the cuBLAS gemm routine.
template <>
cublasStatus_t myCublasGemm<cuComplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                       int m, int n, int k, const cuComplex *alpha, const cuComplex *A, int lda,
                                       const cuComplex *B, int ldb, const cuComplex *beta, cuComplex *C, int ldc)
{
    return cublasCgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

// myCublasGemm specialization for cuDoubleComplex.
// Forwards every argument unchanged to cublasZgemm and returns its status;
// parameter meanings are those of the cuBLAS gemm routine.
template <>
cublasStatus_t myCublasGemm<cuDoubleComplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                             int m, int n, int k, const cuDoubleComplex *alpha,
                                             const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb,
                                             const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc)
{
    return cublasZgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

// myCublasGemm specialization for SuperLU's doublecomplex.
// cuBLAS has no doublecomplex overload, so every scalar and matrix pointer
// is reinterpreted as cuDoubleComplex and the call goes to cublasZgemm.
// NOTE(review): this relies on doublecomplex and cuDoubleComplex sharing the
// same memory layout -- confirm against both type definitions.
template <>
cublasStatus_t myCublasGemm<doublecomplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                           int m, int n, int k, const doublecomplex *alpha,
                                           const doublecomplex *A, int lda, const doublecomplex *B, int ldb,
                                           const doublecomplex *beta, doublecomplex *C, int ldc)
{
    const auto *zAlpha = reinterpret_cast<const cuDoubleComplex *>(alpha);
    const auto *zBeta = reinterpret_cast<const cuDoubleComplex *>(beta);
    const auto *zA = reinterpret_cast<const cuDoubleComplex *>(A);
    const auto *zB = reinterpret_cast<const cuDoubleComplex *>(B);
    auto *zC = reinterpret_cast<cuDoubleComplex *>(C);

    return cublasZgemm(handle, transa, transb, m, n, k,
                       zAlpha, zA, lda, zB, ldb, zBeta, zC, ldc);
}

// myCusolverGetrf specialization for SuperLU's doublecomplex.
// The matrix and workspace pointers are reinterpreted as cuDoubleComplex and
// passed to cusolverDnZgetrf; all other arguments are forwarded unchanged.
// NOTE(review): this relies on doublecomplex and cuDoubleComplex sharing the
// same memory layout -- confirm against both type definitions.
template <>
cusolverStatus_t myCusolverGetrf<doublecomplex>(
    cusolverDnHandle_t handle, int m, int n, doublecomplex *A, int lda,
    doublecomplex *Workspace, int *devIpiv, int *devInfo)
{
    auto *zA = reinterpret_cast<cuDoubleComplex *>(A);
    auto *zWork = reinterpret_cast<cuDoubleComplex *>(Workspace);

    return cusolverDnZgetrf(handle, m, n, zA, lda, zWork, devIpiv, devInfo);
}

// now creating the wrappers for the other functions
// myCublasTrsm specialization for SuperLU's doublecomplex.
// alpha, A and B are reinterpreted as cuDoubleComplex and the call goes to
// cublasZtrsm; all other arguments are forwarded unchanged.
// NOTE(review): this relies on doublecomplex and cuDoubleComplex sharing the
// same memory layout -- confirm against both type definitions.
template <>
cublasStatus_t myCublasTrsm<doublecomplex>(cublasHandle_t handle,
                                           cublasSideMode_t side, cublasFillMode_t uplo,
                                           cublasOperation_t trans, cublasDiagType_t diag,
                                           int m, int n,
                                           const doublecomplex *alpha,
                                           const doublecomplex *A, int lda,
                                           doublecomplex *B, int ldb)
{
    const auto *zAlpha = reinterpret_cast<const cuDoubleComplex *>(alpha);
    const auto *zA = reinterpret_cast<const cuDoubleComplex *>(A);
    auto *zB = reinterpret_cast<cuDoubleComplex *>(B);

    return cublasZtrsm(handle, side, uplo, trans, diag, m, n,
                       zAlpha, zA, lda, zB, ldb);
}

// myCublasScal specialization for SuperLU's doublecomplex.
// alpha and x are reinterpreted as cuDoubleComplex and the call goes to
// cublasZscal; handle, n and incx are forwarded unchanged.
// NOTE(review): this relies on doublecomplex and cuDoubleComplex sharing the
// same memory layout -- confirm against both type definitions.
template <>
cublasStatus_t myCublasScal<doublecomplex>(cublasHandle_t handle, int n,
                                           const doublecomplex *alpha,
                                           doublecomplex *x, int incx)
{
    const auto *zAlpha = reinterpret_cast<const cuDoubleComplex *>(alpha);
    auto *zX = reinterpret_cast<cuDoubleComplex *>(x);

    return cublasZscal(handle, n, zAlpha, zX, incx);
}

// myCublasAxpy specialization for SuperLU's doublecomplex.
// alpha, x and y are reinterpreted as cuDoubleComplex and the call goes to
// cublasZaxpy; handle, n, incx and incy are forwarded unchanged.
// NOTE(review): this relies on doublecomplex and cuDoubleComplex sharing the
// same memory layout -- confirm against both type definitions.
template <>
cublasStatus_t myCublasAxpy<doublecomplex>(cublasHandle_t handle, int n,
                                           const doublecomplex *alpha,
                                           const doublecomplex *x, int incx,
                                           doublecomplex *y, int incy)
{
    const auto *zAlpha = reinterpret_cast<const cuDoubleComplex *>(alpha);
    const auto *zX = reinterpret_cast<const cuDoubleComplex *>(x);
    auto *zY = reinterpret_cast<cuDoubleComplex *>(y);

    return cublasZaxpy(handle, n, zAlpha, zX, incx, zY, incy);
}


// cublasStatus_t myCublasScal<doublecomplex>
// cublasStatus_t myCublasAxpy<doublecomplex>
// cublasStatus_t myCublasGemm<doublecomplex>
5 changes: 3 additions & 2 deletions SRC/TRF3dV100/l_panels_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int_t xlpanel_t<Ftype>::panelSolve(int_t ksupsz, Ftype* DiagBlk, int_t LDD)
lPanelStPtr = blkPtr(1);
len -= nbrow(0);
}
Ftype alpha = 1.0; // {1.0, 0.0}; std::complex<double> alpha = {1.0, 0.0};
Ftype alpha = one<Ftype>(); // {1.0, 0.0}; std::complex<double> alpha = {1.0, 0.0};
superlu_trsm<Ftype>("R", "U", "N", "N",
len, ksupsz, alpha, DiagBlk, LDD,
lPanelStPtr, LDA());
Expand All @@ -83,7 +83,8 @@ int_t xlpanel_t<Ftype>::panelSolve(int_t ksupsz, Ftype* DiagBlk, int_t LDD)


template <typename Ftype>
int_t xlpanel_t<Ftype>::diagFactor(int_t k, Ftype* UBlk, int_t LDU, Ftype thresh, int_t *xsup,
int_t xlpanel_t<Ftype>::diagFactor(int_t k, Ftype* UBlk, int_t LDU,
threshPivValType<Ftype> thresh, int_t *xsup,
superlu_dist_options_t *options,
SuperLUStat_t *stat, int *info)
{
Expand Down
Loading

0 comments on commit a105e52

Please sign in to comment.