From f6c4d614e35b7424774160a23d8e8bef3b15faad Mon Sep 17 00:00:00 2001 From: aledudek Date: Wed, 18 Dec 2024 09:45:58 +0100 Subject: [PATCH] [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm (#1743) * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review changes * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review fix --- example/ck_tile/03_gemm/run_gemm_example.inc | 29 +++- .../run_batched_gemm_example.inc | 33 +++- .../ck_tile/host/reference/reference_gemm.hpp | 162 ++---------------- 3 files changed, 68 insertions(+), 156 deletions(-) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index a1fc155775..2b7a967bab 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_gemm_gpu( - a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index dacca2042e..8345eef95b 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc, c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + batch_count * M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + batch_count * N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_batched_gemm_gpu(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_gpu_buf_ref, + CLayout>(d_A, + d_B, + d_C, M, N, K, @@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc, batch_stride_C, batch_count); + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + batch_count * M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 8bd1f5b048..fc412e8831 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -97,9 +97,9 @@ template -void reference_gemm_gpu(DeviceMem& a_device, - DeviceMem& b_device, - DeviceMem& c_device, +void reference_gemm_gpu(ADataType* a_ptr, + BDataType* b_ptr, + CDataType* c_ptr, index_t M, index_t N, index_t K, @@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device, index_t stride_b, index_t stride_c) { - - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType)); - hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType)); - hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType)); - if(errA != hipSuccess) - { - std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) - << std::endl; - return; // Early exit on error - } - - if(errB != hipSuccess) - { - std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) - << std::endl; - return; // Early exit on error - } - - if(errC != hipSuccess) - { - std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) - << std::endl; - return; // Early exit on error - } - - errA = hipMemcpy( - d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice); - if(errA != hipSuccess) - { - std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipMemcpy( - d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice); - if(errB != hipSuccess) - { - std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl; - } - int totalElements = M * N; int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; naive_gemm_kernel - <<>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); - errC = hipMemcpy( - c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); - if(errC != hipSuccess) - { - std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl; - } - - errA = hipFree(d_A); - if(errA != hipSuccess) - { - std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipFree(d_B); - if(errB != hipSuccess) - { - std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl; - } - - errC = hipFree(d_C); - if(errC != hipSuccess) - { - std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl; - } + <<>>( + a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c); return; } @@ -191,9 +125,9 @@ template -void reference_batched_gemm_gpu(DeviceMem& a_device, - DeviceMem& b_device, - DeviceMem& c_device, +void reference_batched_gemm_gpu(ADataType* a_ptr, + BDataType* b_ptr, + CDataType* c_ptr, index_t M, index_t N, index_t K, @@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device, index_t batch_stride_C, index_t batch_count) { - - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)); - hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)); - hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)); - if(errA != hipSuccess) - { - std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) - << std::endl; - return; // Early exit on error - } - - if(errB != hipSuccess) - { - std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) - << std::endl; - return; // Early exit on error - } - - if(errC != hipSuccess) - { - std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) - << std::endl; - return; // Early exit on error - } - - errA = hipMemcpy(d_A, - a_device.GetDeviceBuffer(), - batch_count * M * K * sizeof(ADataType), - hipMemcpyHostToDevice); - if(errA != hipSuccess) - { - std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipMemcpy(d_B, - b_device.GetDeviceBuffer(), - batch_count * N * K * sizeof(BDataType), - hipMemcpyHostToDevice); - if(errB != hipSuccess) - { - std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl; - } - int totalElements = M * N; int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; for(index_t batch_id = 0; batch_id < batch_count; ++batch_id) { - ADataType* d_ATemp = d_A + batch_id * batch_stride_A; - BDataType* d_BTemp = d_B + batch_id * batch_stride_B; - CDataType* d_CTemp = d_C + batch_id * batch_stride_C; + ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A; + BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B; + CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C; naive_gemm_kernel <<>>( d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c); } - errC = hipMemcpy(c_device.GetDeviceBuffer(), - d_C, - batch_count * M * N * sizeof(CDataType), - hipMemcpyDeviceToHost); - if(errC != hipSuccess) - { - std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl; - } - - errA = hipFree(d_A); - if(errA != hipSuccess) - { - std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipFree(d_B); - if(errB != hipSuccess) - { - std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl; - } - - errC = hipFree(d_C); - if(errC != hipSuccess) - { - std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl; - } - return; } } // namespace ck_tile