From 98a36e8f5fdf3d60431b94ede2b2aa162e5c6142 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 14 Sep 2023 13:56:18 +0200 Subject: [PATCH] Use float dispatch macro in Gemm op --- src/ops/gemm.cc | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index bd7eb864b..e6ff87f9d 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -55,25 +55,15 @@ namespace ctranslate2 { break; case DataType::FLOAT32: - DEVICE_DISPATCH(a.device(), (compute(a, b, c, a_shift_compensation))); - break; - -#ifdef CT2_WITH_CUDA case DataType::FLOAT16: - if (a.device() != Device::CUDA) - throw std::invalid_argument("FP16 GEMM is only supported on GPU"); - compute(a, b, c, a_shift_compensation); + case DataType::BFLOAT16: { + DEVICE_AND_FLOAT_DISPATCH("Gemm", a.device(), a.dtype(), + (compute(a, b, c, a_shift_compensation))); break; - - case DataType::BFLOAT16: - if (a.device() != Device::CUDA) - throw std::invalid_argument("BF16 GEMM is only supported on GPU"); - compute(a, b, c, a_shift_compensation); - break; -#endif + } default: - throw std::invalid_argument("unsupported compute type " + dtype_name(a.dtype())); + throw std::invalid_argument("Gemm: unsupported input type " + dtype_name(a.dtype())); } apply_bias_and_activation(c, bias, _activation_type);