Skip to content

Commit

Permalink
Feature: Allow directly compiling CUDA version on DCU hardware (#5727)
Browse files Browse the repository at this point in the history
* Initial commit

* Modify CMakeLists
  • Loading branch information
Critsium-xy authored Dec 13, 2024
1 parent 4891b2e commit ccd2874
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ option(ENABLE_CNPY "Enable cnpy usage." OFF)
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
option(USE_DSP "Enable DSP usage." OFF)
option(USE_CUDA_ON_DCU "Enable CUDA on DCU" OFF)

# enable json support
if(ENABLE_RAPIDJSON)
Expand Down Expand Up @@ -126,6 +127,10 @@ if (USE_DSP)
set(ABACUS_BIN_NAME abacus_dsp)
endif()

if (USE_CUDA_ON_DCU)
add_compile_definitions(__CUDA_ON_DCU)
endif()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

if(ENABLE_COVERAGE)
Expand Down
2 changes: 1 addition & 1 deletion source/module_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str
* @brief for compatibility with __CUDA_ARCH__ 600 and earlier
*
*/
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 && !defined(__CUDA_ON_DCU)
static __inline__ __device__ double atomicAdd(double* address, double val)
{
unsigned long long int* address_as_ull = (unsigned long long int*)address;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ void cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
int t_num = (npw%t_size) ? (npw/t_size + 1) : (npw/t_size);
dim3 npwgrid(((t_num%THREADS_PER_BLOCK) ? (t_num/THREADS_PER_BLOCK + 1) : (t_num/THREADS_PER_BLOCK)));

cal_force_npw << < npwgrid, THREADS_PER_BLOCK >> > (
cal_force_npw <<< npwgrid, THREADS_PER_BLOCK >>> (
reinterpret_cast<const thrust::complex<FPTYPE>*>(psiv),
gv_x, gv_y, gv_z, rhocgigg_vec, force, pos_x, pos_y, pos_z,
npw, omega, tpiba
Expand Down

0 comments on commit ccd2874

Please sign in to comment.