Skip to content

Commit

Permalink
Use float dispatch macro in Gemm op
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Sep 14, 2023
1 parent e2b0133 commit 98a36e8
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions src/ops/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,15 @@ namespace ctranslate2 {
break;

case DataType::FLOAT32:
DEVICE_DISPATCH(a.device(), (compute<D, float, float>(a, b, c, a_shift_compensation)));
break;

#ifdef CT2_WITH_CUDA
case DataType::FLOAT16:
if (a.device() != Device::CUDA)
throw std::invalid_argument("FP16 GEMM is only supported on GPU");
compute<Device::CUDA, float16_t, float16_t>(a, b, c, a_shift_compensation);
case DataType::BFLOAT16: {
DEVICE_AND_FLOAT_DISPATCH("Gemm", a.device(), a.dtype(),
(compute<D, T, T>(a, b, c, a_shift_compensation)));
break;

case DataType::BFLOAT16:
if (a.device() != Device::CUDA)
throw std::invalid_argument("BF16 GEMM is only supported on GPU");
compute<Device::CUDA, bfloat16_t, bfloat16_t>(a, b, c, a_shift_compensation);
break;
#endif
}

default:
throw std::invalid_argument("unsupported compute type " + dtype_name(a.dtype()));
throw std::invalid_argument("Gemm: unsupported input type " + dtype_name(a.dtype()));
}

apply_bias_and_activation(c, bias, _activation_type);
Expand Down

0 comments on commit 98a36e8

Please sign in to comment.