Skip to content

Commit

Permalink
[TAT.hpp] Remove type for template alias of blas/lapack function.
Browse files Browse the repository at this point in the history
In fact, the type for template alias does not affect the
specialization, while in TAT's situation, it always use the
specialized alias, so, default type of template alias is useless, and
now it is removed. Besides, add `constexpr` specifier to the
specialization, since the template default specifier seems not affect
the specialized value.
  • Loading branch information
hzhangxyz committed Sep 20, 2023
1 parent a274000 commit 9ec7071
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 104 deletions.
50 changes: 10 additions & 40 deletions include/TAT/implement/contract.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,59 +162,29 @@ extern "C" {
namespace TAT {
namespace detail {
template<typename ScalarType>
constexpr int (*gemm)(
const char* transpose_a,
const char* transpose_b,
const int* m,
const int* n,
const int* k,
const ScalarType* alpha,
const ScalarType* a,
const int* lda,
const ScalarType* b,
const int* ldb,
const ScalarType* beta,
ScalarType* c,
const int* ldc
) = nullptr;
constexpr auto gemm = nullptr;

template<>
inline auto gemm<float> = sgemm_;
inline constexpr auto gemm<float> = sgemm_;
template<>
inline auto gemm<double> = dgemm_;
inline constexpr auto gemm<double> = dgemm_;
template<>
inline auto gemm<std::complex<float>> = cgemm_;
inline constexpr auto gemm<std::complex<float>> = cgemm_;
template<>
inline auto gemm<std::complex<double>> = zgemm_;
inline constexpr auto gemm<std::complex<double>> = zgemm_;

template<typename ScalarType>
constexpr int (*mkl_gemm_batch)(
const char* transpose_a,
const char* transpose_b,
const int* m,
const int* n,
const int* k,
const ScalarType* alpha,
const ScalarType** a,
const int* lda,
const ScalarType** b,
const int* ldb,
const ScalarType* beta,
ScalarType** c,
const int* ldc,
const int* group_count,
const int* group_size
) = nullptr;
constexpr auto mkl_gemm_batch = nullptr;

#ifdef TAT_USE_MKL_GEMM_BATCH
template<>
inline auto mkl_gemm_batch<float> = sgemm_batch_;
inline constexpr auto mkl_gemm_batch<float> = sgemm_batch_;
template<>
inline auto mkl_gemm_batch<double> = dgemm_batch_;
inline constexpr auto mkl_gemm_batch<double> = dgemm_batch_;
template<>
inline auto mkl_gemm_batch<std::complex<float>> = cgemm_batch_;
inline constexpr auto mkl_gemm_batch<std::complex<float>> = cgemm_batch_;
template<>
inline auto mkl_gemm_batch<std::complex<double>> = zgemm_batch_;
inline constexpr auto mkl_gemm_batch<std::complex<double>> = zgemm_batch_;
#endif
} // namespace detail

Expand Down
11 changes: 5 additions & 6 deletions include/TAT/implement/exponential.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ extern "C" {
namespace TAT {
namespace detail {
template<typename ScalarType>
constexpr int (*gesv)(const int* n, const int* nrhs, ScalarType* A, const int* lda, int* ipiv, ScalarType* B, const int* ldb, int* info) =
nullptr;
constexpr auto gesv = nullptr;
template<>
inline auto gesv<float> = sgesv_;
inline constexpr auto gesv<float> = sgesv_;
template<>
inline auto gesv<double> = dgesv_;
inline constexpr auto gesv<double> = dgesv_;
template<>
inline auto gesv<std::complex<float>> = cgesv_;
inline constexpr auto gesv<std::complex<float>> = cgesv_;
template<>
inline auto gesv<std::complex<double>> = zgesv_;
inline constexpr auto gesv<std::complex<double>> = zgesv_;

template<typename ScalarType>
void linear_solve(int n, ScalarType* A, int nrhs, ScalarType* B, ScalarType* X) {
Expand Down
78 changes: 20 additions & 58 deletions include/TAT/implement/qr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,83 +126,45 @@ namespace TAT {

namespace detail {
template<typename ScalarType>
constexpr int (*geqrf)(
const int* m,
const int* n,
ScalarType* A,
const int* lda,
ScalarType* tau,
ScalarType* work,
const int* lwork,
int* info
) = nullptr;
constexpr auto geqrf = nullptr;
template<>
inline auto geqrf<float> = sgeqrf_;
inline constexpr auto geqrf<float> = sgeqrf_;
template<>
inline auto geqrf<double> = dgeqrf_;
inline constexpr auto geqrf<double> = dgeqrf_;
template<>
inline auto geqrf<std::complex<float>> = cgeqrf_;
inline constexpr auto geqrf<std::complex<float>> = cgeqrf_;
template<>
inline auto geqrf<std::complex<double>> = zgeqrf_;
inline constexpr auto geqrf<std::complex<double>> = zgeqrf_;
template<typename ScalarType>
constexpr int (*gelqf)(
const int* m,
const int* n,
ScalarType* A,
const int* lda,
ScalarType* tau,
ScalarType* work,
const int* lwork,
int* info
) = nullptr;
constexpr auto gelqf = nullptr;
template<>
inline auto gelqf<float> = sgelqf_;
inline constexpr auto gelqf<float> = sgelqf_;
template<>
inline auto gelqf<double> = dgelqf_;
inline constexpr auto gelqf<double> = dgelqf_;
template<>
inline auto gelqf<std::complex<float>> = cgelqf_;
inline constexpr auto gelqf<std::complex<float>> = cgelqf_;
template<>
inline auto gelqf<std::complex<double>> = zgelqf_;
inline constexpr auto gelqf<std::complex<double>> = zgelqf_;
template<typename ScalarType>
constexpr int (*orgqr)(
const int* m,
const int* n,
const int* k,
ScalarType* A,
const int* lda,
ScalarType* tau,
ScalarType* work,
const int* lwork,
int* info
) = nullptr;
constexpr auto orgqr = nullptr;
template<>
inline auto orgqr<float> = sorgqr_;
inline constexpr auto orgqr<float> = sorgqr_;
template<>
inline auto orgqr<double> = dorgqr_;
inline constexpr auto orgqr<double> = dorgqr_;
template<>
inline auto orgqr<std::complex<float>> = cungqr_;
inline constexpr auto orgqr<std::complex<float>> = cungqr_;
template<>
inline auto orgqr<std::complex<double>> = zungqr_;
inline constexpr auto orgqr<std::complex<double>> = zungqr_;
template<typename ScalarType>
constexpr int (*orglq)(
const int* m,
const int* n,
const int* k,
ScalarType* A,
const int* lda,
ScalarType* tau,
ScalarType* work,
const int* lwork,
int* info
) = nullptr;
constexpr auto orglq = nullptr;
template<>
inline auto orglq<float> = sorglq_;
inline constexpr auto orglq<float> = sorglq_;
template<>
inline auto orglq<double> = dorglq_;
inline constexpr auto orglq<double> = dorglq_;
template<>
inline auto orglq<std::complex<float>> = cunglq_;
inline constexpr auto orglq<std::complex<float>> = cunglq_;
template<>
inline auto orglq<std::complex<double>> = zunglq_;
inline constexpr auto orglq<std::complex<double>> = zunglq_;

template<typename ScalarType>
int to_int(const ScalarType& value) {
Expand Down

0 comments on commit 9ec7071

Please sign in to comment.