Support cusparseLt v0.6.3 as backend.
vin-huang committed Dec 20, 2024
1 parent ff3469f commit 7c2f219
Showing 9 changed files with 57 additions and 36 deletions.
6 changes: 0 additions & 6 deletions clients/benchmarks/client.cpp
@@ -89,14 +89,8 @@ struct perf_sparse<
Tc,
TBias,
std::enable_if_t<
-#ifdef __HIP_PLATFORM_AMD__
(std::is_same<Ti, To>{} && (std::is_same<Ti, __half>{} || std::is_same<Ti, hip_bfloat16>{})
&& std::is_same<Tc, float>{})
-#else
-(std::is_same<Ti, To>{}
- && ((std::is_same<Ti, __half>{} && std::is_same<Tc, __half>{})
-     || (std::is_same<Ti, hip_bfloat16>{} && std::is_same<Tc, hip_bfloat16>{})))
-#endif
|| (std::is_same<Ti, To>{} && (std::is_same<Ti, int8_t>{}) && std::is_same<Tc, int32_t>{})
|| (std::is_same<Ti, int8_t>{} && (std::is_same<To, __half>{})
&& std::is_same<Tc, int32_t>{})
4 changes: 2 additions & 2 deletions clients/include/spmm/testing_spmm.hpp
@@ -845,7 +845,7 @@ void testing_spmm(const Arguments& arg)
hD_gold_act + stride_d * i,
ldd,
tSizeD,
-arg.alpha_vector_scaling ? hAlpahVector : nullptr,
+arg.alpha_vector_scaling ? hAlpahVector : (float*)nullptr,
false);

auto pos = stride_d * i;
@@ -916,7 +916,7 @@ void testing_spmm(const Arguments& arg)
hD_gold + stride_d * i,
ldd,
tSizeD,
-arg.alpha_vector_scaling ? hAlpahVector : nullptr,
+arg.alpha_vector_scaling ? hAlpahVector : (float*)nullptr,
false);
}
#undef activation_param
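Side note (not part of the diff): the added (float*) cast matters if hAlpahVector is a test-harness buffer type rather than a raw float* — an assumption, since the real type is not shown in this hunk. With a class type on one arm of the ternary and nullptr on the other, there is no common type and the expression fails to compile; casting nullptr to float* forces both arms through the pointer conversion. A minimal sketch, with a hypothetical HostVector standing in for the real buffer type:

// Hypothetical stand-in for a host-side buffer that converts to float*.
struct HostVector
{
    float* ptr = nullptr;
    operator float*() { return ptr; }
};

float* pick(bool scaling, HostVector& v)
{
    // return scaling ? v : nullptr;      // ill-formed: no common type
    return scaling ? v : (float*)nullptr; // OK: both arms become float*
}

int main()
{
    HostVector v;
    return pick(false, v) == nullptr ? 0 : 1;
}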
4 changes: 2 additions & 2 deletions clients/include/utility.hpp
@@ -545,8 +545,8 @@ inline hipsparseStatus_t expected_hipsparse_status_of_matrix_size(hipDataType
switch(type)
{
case HIP_R_8I:
-case HIP_R_8F_E4M3_FNUZ:
-case HIP_R_8F_E5M2_FNUZ:
+case HIP_R_8F_E4M3:
+case HIP_R_8F_E5M2:
if(isSparse)
row_ = col_ = ld_ = 32;
else
22 changes: 16 additions & 6 deletions clients/samples/example_spmm_strided_batched.cpp
@@ -113,16 +113,26 @@ inline bool AlmostEqual(__half a, __half b)
_Float16 data;
};

-_HALF a_half = {__half_raw(a).x};
-_HALF b_half = {__half_raw(b).x};
+_HALF a_half    = {__half_raw(a).x};
+_HALF b_half    = {__half_raw(b).x};
+_HALF zero_half = {__half_raw(static_cast<__half>(0)).x};
+_HALF one_half  = {__half_raw(static_cast<__half>(1)).x};
+_HALF e_n2_half = {__half_raw(static_cast<__half>(0.01)).x};

auto a_data = a_half.data;
auto b_data = b_half.data;
+auto zero   = zero_half.data;
+auto one    = one_half.data;
+auto e_n2   = e_n2_half.data;
#else
auto a_data = a;
auto b_data = b;
+auto zero = __half(0);
+auto one = __half(1);
+auto e_n2 = __half(0.01);
#endif
-auto absA = (a_data > 0.0) ? a_data : static_cast<decltype(a_data)>(-a_data);
-auto absB = (b_data > 0.0) ? b_data : static_cast<decltype(b_data)>(-b_data);
+auto absA = (a_data > zero) ? a_data : static_cast<decltype(a_data)>(-a_data);
+auto absB = (b_data > zero) ? b_data : static_cast<decltype(b_data)>(-b_data);
// this avoids NaN when inf is compared against inf in the alternative code
// path
if(static_cast<float>(absA) == std::numeric_limits<float>::infinity()
@@ -135,8 +145,8 @@ inline bool AlmostEqual(__half a, __half b)
{
return a_data == b_data;
}
-auto absDiff = (a_data - b_data > 0) ? a_data - b_data : b_data - a_data;
-return absDiff / (absA + absB + 1) < 0.01;
+auto absDiff = (a_data - b_data > zero) ? a_data - b_data : b_data - a_data;
+return absDiff / (absA + absB + one) < e_n2;
}

inline void extract_metadata(unsigned metadata, int& a, int& b, int& c, int& d)
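For reference, the arithmetic AlmostEqual performs (now with __half-typed constants instead of double literals, so the comparisons stay in half precision) is a symmetric relative-error test. A minimal float sketch of the same check — standalone, not part of the commit:

#include <cmath>
#include <cstdio>
#include <limits>

// Sketch of the tolerance test AlmostEqual performs: symmetric relative
// error |a-b| / (|a| + |b| + 1) < 0.01, with a +1 damping term so values
// near zero do not blow up the ratio.
bool almost_equal_f(float a, float b)
{
    float absA = std::fabs(a);
    float absB = std::fabs(b);
    // Avoid NaN when inf is compared against inf.
    if(absA == std::numeric_limits<float>::infinity()
       || absB == std::numeric_limits<float>::infinity())
        return a == b;
    return std::fabs(a - b) / (absA + absB + 1.0f) < 0.01f;
}

int main()
{
    printf("%d\n", almost_equal_f(1.0f, 1.005f)); // 1: within tolerance
    printf("%d\n", almost_equal_f(1.0f, 1.5f));   // 0: outside tolerance
}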
5 changes: 3 additions & 2 deletions library/CMakeLists.txt
@@ -1,5 +1,5 @@
# ########################################################################
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
+# Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -142,7 +142,8 @@ if(NOT BUILD_CUDA)
# Target link libraries
target_link_libraries(hipsparselt PRIVATE hip::device ${DL_LIB})
else()
-target_link_libraries(hipsparselt PRIVATE /usr/lib/x86_64-linux-gnu/libcusparseLt.so ${CUDA_CUSPARSE_LIBRARY})
+find_library(CUDA_CUSPARSELT_LIBRARY NAMES cusparseLt PATHS /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64 REQUIRED)
+target_link_libraries(hipsparselt PRIVATE ${CUDA_CUSPARSELT_LIBRARY} ${CUDA_CUSPARSE_LIBRARY})
endif()

# Target properties
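Because find_library caches its result, a cusparseLt installed outside the searched paths can be supplied at configure time without editing this file. A hypothetical invocation (the /opt/cusparselt prefix is an example, not something this commit prescribes):

cmake -DBUILD_CUDA=ON -DCUDA_CUSPARSELT_LIBRARY=/opt/cusparselt/lib/libcusparseLt.so ..

With REQUIRED, configuration now fails early with a clear message when the library is absent, instead of failing at link time against a hard-coded path.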
1 change: 1 addition & 0 deletions library/include/hipsparselt.h
@@ -236,6 +236,7 @@ typedef enum {
When Input's datatype is FP16 - Bias type can be FP16 or FP32. (default FP16)
When Input's datatype is BF16 - Bias type can be BF16 or FP32. (default BF16)
In other cases - Bias type is FP32.*/
+HIPSPARSELT_MATMUL_SPARSE_MAT_POINTER = 17, /**< Pointer to the pruned sparse matrix. */
} hipsparseLtMatmulDescAttribute_t;

/*! \ingroup types_module
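A minimal sketch of how the new attribute would be set, assuming the usual hipsparseLtMatmulDescSetAttribute entry point; the handle/descriptor setup and the d_sparse pointer are placeholders, not code from this commit:

hipsparseLtHandle_t handle;
hipsparseLtMatmulDescriptor_t matmulDescr;
// ... handle and descriptor initialized via hipsparseLtInit and
// hipsparseLtMatmulDescriptorInit ...
void* d_sparse = nullptr; // hypothetical device pointer to the pruned sparse matrix
hipsparseLtMatmulDescSetAttribute(&handle,
                                  &matmulDescr,
                                  HIPSPARSELT_MATMUL_SPARSE_MAT_POINTER,
                                  &d_sparse,
                                  sizeof(d_sparse));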
5 changes: 5 additions & 0 deletions library/src/auxiliary.cpp
@@ -44,8 +44,13 @@ const hipDataType string_to_hip_datatype(const std::string& value)
value == "f16_r" || value == "h" ? HIP_R_16F :
value == "bf16_r" ? HIP_R_16BF :
value == "i8_r" ? HIP_R_8I :
+#ifdef __HIP_PLATFORM_AMD__
value == "f8_r" ? HIP_R_8F_E4M3_FNUZ :
value == "bf8_r" ? HIP_R_8F_E5M2_FNUZ :
+#else
+value == "f8_r" ? HIP_R_8F_E4M3 :
+value == "bf8_r" ? HIP_R_8F_E5M2 :
+#endif
static_cast<hipDataType>(-1);
}

17 changes: 13 additions & 4 deletions library/src/include/auxiliary.hpp
@@ -144,9 +144,15 @@ constexpr const char* hip_datatype_to_string(hipDataType type)
return "bf16_r";
case HIP_R_8I:
return "i8_r";
+#ifdef __HIP_PLATFORM_AMD__
+case HIP_R_8F_E4M3_FNUZ:
+#endif
case HIP_R_8F_E4M3:
return "f8_r";
+#ifdef __HIP_PLATFORM_AMD__
+case HIP_R_8F_E5M2_FNUZ:
+#endif
case HIP_R_8F_E5M2:
return "bf8_r";
default:
return "invalid";
Expand All @@ -165,10 +171,6 @@ constexpr const char* hipsparselt_computetype_to_string(hipsparseLtComputetype_t
return "i32_r";
case HIPSPARSELT_COMPUTE_32F:
return "f32_r";
case HIPSPARSELT_COMPUTE_TF32:
return "tf32_r";
case HIPSPARSELT_COMPUTE_TF32_FAST:
return "tf32f_r";
}
return "invalid";
}
@@ -223,6 +225,13 @@ __host__ __device__ inline bool hipsparselt_isnan(__half arg)
return (~x.x & 0x7c00) == 0 && (x.x & 0x3ff) != 0;
}

+#ifdef __HIP_PLATFORM_NVIDIA__
+__host__ __device__ inline bool hipsparselt_isnan(__nv_bfloat16 arg)
+{
+return __hisnan(arg);
+}
+#endif
+
/*******************************************************************************
* \brief returns true if arg is Infinity
********************************************************************************/
29 changes: 15 additions & 14 deletions library/src/nvcc_detail/hipsparselt.cpp
@@ -30,6 +30,7 @@
#include <cusparseLt.h>
#include <stdio.h>
#include <stdlib.h>
+#include <string.h>

#define TO_STR2(x) #x
#define TO_STR(x) TO_STR2(x)
@@ -114,7 +115,7 @@ hipsparseStatus_t hipCUSPARSEStatusToHIPStatus(cusparseStatus_t cuStatus)
}

/* @deprecated */
-cudaDataType HIPDatatypeToCuSparseLtDatatype(hipsparseLtDatatype_t type)
+cudaDataType HIPSparseLtDatatypeToCuSparseLtDatatype(hipsparseLtDatatype_t type)
{
switch(type)
{
@@ -185,10 +186,8 @@ cusparseComputeType HIPComputetypeToCuSparseComputetype(hipsparseLtComputetype_t
return CUSPARSE_COMPUTE_16F;
case HIPSPARSELT_COMPUTE_32I:
return CUSPARSE_COMPUTE_32I;
-case HIPSPARSELT_COMPUTE_TF32:
-return CUSPARSE_COMPUTE_TF32;
-case HIPSPARSELT_COMPUTE_TF32_FAST:
-return CUSPARSE_COMPUTE_TF32_FAST;
+case HIPSPARSELT_COMPUTE_32F:
+return CUSPARSE_COMPUTE_32F;
default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
@@ -202,10 +201,8 @@ hipsparseLtComputetype_t CuSparseLtComputetypeToHIPComputetype(cusparseComputeTy
return HIPSPARSELT_COMPUTE_16F;
case CUSPARSE_COMPUTE_32I:
return HIPSPARSELT_COMPUTE_32I;
-case CUSPARSE_COMPUTE_TF32:
-return HIPSPARSELT_COMPUTE_TF32;
-case CUSPARSE_COMPUTE_TF32_FAST:
-return HIPSPARSELT_COMPUTE_TF32_FAST;
+case CUSPARSE_COMPUTE_32F:
+return HIPSPARSELT_COMPUTE_32F;
default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
@@ -312,6 +309,8 @@ cusparseLtMatmulDescAttribute_t
return CUSPARSELT_MATMUL_BIAS_STRIDE;
case HIPSPARSELT_MATMUL_BIAS_POINTER:
return CUSPARSELT_MATMUL_BIAS_POINTER;
+case HIPSPARSELT_MATMUL_SPARSE_MAT_POINTER:
+return CUSPARSELT_MATMUL_SPARSE_MAT_POINTER;
default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
@@ -340,6 +339,8 @@ hipsparseLtMatmulDescAttribute_t
return HIPSPARSELT_MATMUL_BIAS_STRIDE;
case CUSPARSELT_MATMUL_BIAS_POINTER:
return HIPSPARSELT_MATMUL_BIAS_POINTER;
+case CUSPARSELT_MATMUL_SPARSE_MAT_POINTER:
+return HIPSPARSELT_MATMUL_SPARSE_MAT_POINTER;
default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
@@ -531,7 +532,9 @@ hipsparseStatus_t hipsparseLtInit(hipsparseLtHandle_t* handle)
if((log_env = getenv("HIPSPARSELT_LOG_MASK")) != NULL)
{
int mask = strtol(log_env, nullptr, 0);
-setenv("CUSPARSELT_LOG_MASK", std::to_string(mask).c_str(), 0);
+char mask_str[11];
+snprintf(mask_str, 11, "%d",mask);
+setenv("CUSPARSELT_LOG_MASK", mask_str, 0);
}
if((log_env = getenv("HIPSPARSELT_LOG_FILE")) != NULL)
{
@@ -967,10 +970,8 @@ catch(...)
hipsparseStatus_t hipsparseLtGetArchName(char** archName)
try
{
-*archName = nullptr;
-std::string arch = "cuda";
-*archName = (char*)malloc(arch.size() * sizeof(char));
-strncpy(*archName, arch.c_str(), arch.size());
+*archName = (char*)malloc(5);
+snprintf(*archName, 5, "cuda\0");
return HIPSPARSE_STATUS_SUCCESS;
}
catch(...)
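Usage sketch for the rewritten hipsparseLtGetArchName — not part of the commit; the caller-frees contract is unchanged from the old malloc-based version:

char* archName = nullptr;
if(hipsparseLtGetArchName(&archName) == HIPSPARSE_STATUS_SUCCESS)
{
    printf("backend arch: %s\n", archName); // "cuda" on the cuSPARSELt backend
    free(archName); // the buffer comes from malloc(), so the caller frees it
}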
