diff --git a/CMakeLists.txt b/CMakeLists.txt
index 27b6b150e7..5743ff78e3 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,10 +17,10 @@ project(TurboMind LANGUAGES CXX CUDA)
 
 find_package(CUDA 10.2 REQUIRED)
 
-# if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
-#   add_definitions("-DENABLE_BF16")
-#   message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
-# endif()
+if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
+  add_definitions("-DENABLE_BF16")
+  message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
+endif()
 
 # if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12"))
 #   add_definitions("-DENABLE_FP8")
diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py
index 5bcab7b537..02a9957f58 100644
--- a/lmdeploy/turbomind/deploy/converter.py
+++ b/lmdeploy/turbomind/deploy/converter.py
@@ -6,6 +6,7 @@
 from pathlib import Path
 
 import fire
+import torch
 from huggingface_hub import snapshot_download
 
 from lmdeploy.model import MODELS
@@ -113,6 +114,55 @@ def copy_tokenizer(model_path: str, tokenizer_path: str,
                     osp.join(triton_models_path, 'tokenizer'))
 
 
+def update_output_format(model_name: str, model_format: str, model_path: str,
+                         output_format: str):
+    """Update output format according to model info."""
+    TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16'}
+    MODEL_NAME_MAP = {'qwen': 'bf16', 'llama': 'half'}
+    model_name = model_name.split('-')[0]
+
+    def _fix_device_support(output_format):
+        """fix device support."""
+        if output_format == 'bf16':
+            if not torch.cuda.is_bf16_supported():
+                # device does not support bf16
+                print('Device does not support bf16.')
+                output_format = 'fp16'
+        return output_format
+
+    def _infer_output_format(config):
+        """_infer_output_format."""
+        torch_dtype = getattr(config, 'torch_dtype', None)
+        if torch_dtype:
+            updated_output_format = TORCH_DTYPE_MAP.get(
+                torch_dtype, output_format)
+        else:
+            # get model name prefix
+            updated_output_format = MODEL_NAME_MAP.get(model_name,
+                                                       output_format)
+        return _fix_device_support(updated_output_format)
+
+    if model_format in MODEL_NAME_MAP:
+        updated_output_format = MODEL_NAME_MAP.get(model_name, output_format)
+        return _fix_device_support(updated_output_format)
+    else:
+        from transformers import AutoConfig
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        return _infer_output_format(config)
+
+
+def update_config_weight_type(output_format: str,
+                              config: TurbomindModelConfig):
+    WEIGHT_TYPE_MAP = {
+        'fp32': 'fp32',
+        'fp16': 'fp16',
+        'bf16': 'bf16',
+        'w4': 'int4',
+        'w8': 'int8'
+    }
+    config.weight_type = WEIGHT_TYPE_MAP[output_format]
+
+
 def pack_model_repository(workspace_path: str):
     """package the model repository.
@@ -215,6 +265,10 @@ def main(model_name: str,
         cfg.weight_type = 'int4'
         output_format = 'w4'
         assert group_size > 0, f'group_size: {group_size} should > 0'
+    else:
+        output_format = update_output_format(model_name, inferred_model_format,
+                                             model_path, output_format)
+        update_config_weight_type(output_format, cfg)
 
     # convert
     print('model_name ', model_name)
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index 92e6232301..a1ca79b98d 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -83,6 +83,14 @@ def valid(self):
         return True
 
 
+_WEIGHT_DTYPE_MAP = dict(
+    int4=torch.float16,
+    fp16=torch.float16,
+    fp32=torch.float16,
+    bf16=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+)
+
+
 class BaseOutputModel(ABC):
     """Base output model."""
 
@@ -136,19 +144,26 @@ def export_config(self) -> None:
 
     def export_weight(self, param: torch.Tensor, name: str) -> None:
         """export turbomind weight."""
+
+        def _tofile(tensor, path):
+            """to file."""
+            if tensor.dtype == torch.bfloat16:
+                tensor = tensor.view(torch.half)
+            tensor.contiguous().cpu().numpy().tofile(path)
+
         if self.to_file:
-            if param.dtype in [torch.float, torch.bfloat16]:
-                param = param.half()
+            torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
+                                               torch.float16)
+            param = param.to(torch_type)
             tprint(name, param.shape)
-            param.contiguous().cpu().numpy().tofile(
-                osp.join(self.out_dir, name))
+            _tofile(param, osp.join(self.out_dir, name))
         elif len(self.tm_params) > 0:
             tm_params = self.tm_params
             weight_type = self.cfg.weight_type
-            assert weight_type in ['fp16', 'fp32', 'int4']
+            assert weight_type in ['fp16', 'fp32', 'bf16', 'int4']
             # currently, the tensor type should in
-            # [torch.float, torch.half, torch.int32]
+            # [torch.float, torch.half, torch.bfloat16, torch.int32]
             torch_tensor = param.cuda().contiguous()
             assert torch_tensor.dtype in [
                 torch.int32, torch.float, torch.half, torch.bfloat16
@@ -156,6 +171,8 @@ def export_weight(self, param: torch.Tensor, name: str) -> None:
             if torch_tensor.dtype != torch.int32:
                 if weight_type in ['fp16', 'int4']:
                     torch_tensor = torch_tensor.half()
+                elif weight_type == 'bf16':
+                    torch_tensor = torch_tensor.bfloat16()
                 else:
                     torch_tensor = torch_tensor.float()
             for tm_tensor in tm_params[name]:
diff --git a/lmdeploy/turbomind/deploy/target_model/fp.py b/lmdeploy/turbomind/deploy/target_model/fp.py
index d9a7783436..ebebc2fae8 100644
--- a/lmdeploy/turbomind/deploy/target_model/fp.py
+++ b/lmdeploy/turbomind/deploy/target_model/fp.py
@@ -14,7 +14,7 @@ def transpose_tensor(input: List[torch.Tensor]):
     return output
 
 
-@OUTPUT_MODELS.register_module(name='fp16')
+@OUTPUT_MODELS.register_module(name=['fp16', 'bf16'])
 class TurbomindModel(BaseOutputModel):
     """Export to turbomind fp16 format."""
 
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index ad7c0cb518..3d7522462f 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -22,7 +22,8 @@
 from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.utils import get_logger
 
-from .deploy.converter import get_model_format, supported_formats
+from .deploy.converter import (get_model_format, supported_formats,
+                               update_config_weight_type, update_output_format)
 from .deploy.source_model.base import INPUT_MODELS
 from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
 from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
@@ -246,6
+247,12 @@ def _from_hf(self, output_format = 'w4' data_type = 'int4' assert group_size > 0, f'group_size: {group_size} should > 0' + else: + output_format = update_output_format(model_name, + inferred_model_format, + model_path, output_format) + data_type = output_format + update_config_weight_type(output_format, cfg) self.config = cfg self.model_name = model_name diff --git a/src/turbomind/kernels/activation_kernels.cu b/src/turbomind/kernels/activation_kernels.cu index 97bd7961d1..69b10c78c0 100644 --- a/src/turbomind/kernels/activation_kernels.cu +++ b/src/turbomind/kernels/activation_kernels.cu @@ -117,7 +117,11 @@ struct ReluActivation<__nv_bfloat162> { static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) { const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + return turbomind::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16); +#else return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16); +#endif } }; #endif diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index 85ece1fa99..e853c242ec 100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -590,6 +590,39 @@ __inline__ __device__ uint4 vec_conversion(const Float8_& a) return b; } +#ifdef ENABLE_BF16 +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) +{ + return cuda_cast<__nv_bfloat162, float2>(a); +} +template<> +__inline__ __device__ bf16_4_t vec_conversion(const float4& a) +{ + bf16_4_t b; + float2 val; + val.x = a.x; + val.y = a.y; + b.x = vec_conversion<__nv_bfloat162, float2>(val); + + val.x = a.z; + val.y = a.w; + b.y = vec_conversion<__nv_bfloat162, float2>(val); + + return b; +} +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) +{ + bf16_8_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + b.z = vec_conversion<__nv_bfloat162, float2>(a.z); + b.w = vec_conversion<__nv_bfloat162, float2>(a.w); + return b; +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -1053,6 +1086,55 @@ inline __device__ int64_t quant(uint4 a, const float scale, const float zp) int16[3] = quant(a.w, scale, zp); return int64; } + +// bfloat16 to int8 +inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp) +{ + int8_t int8; + float b = bfloat16_to_float(a); + int8 = round(max(-128.f, min(127.f, (b - zp) / scale))); + return int8; +} +// bfloat16x2 to int8x2 +inline __device__ int16_t quant(__nv_bfloat162 a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = bfloat162_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale))); + return int16; +} +// bfloat16x4 to int8x4 +inline __device__ int32_t quant(bf16_4_t a, const float scale, const float zp) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = 
quant(a.y, scale, zp); + return int32; +} +// bfloat16x8 to int8x8 +inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + int16[2] = quant(a.z, scale, zp); + int16[3] = quant(a.w, scale, zp); + return int64; +} + // int8 to float32, then `vec_conversion` to target format inline __device__ float dequant(int8_t a, const float scale, const float zp) { diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention_utils.h b/src/turbomind/kernels/decoder_masked_multihead_attention_utils.h index 6479647799..99f68648aa 100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention_utils.h +++ b/src/turbomind/kernels/decoder_masked_multihead_attention_utils.h @@ -43,6 +43,24 @@ struct Float4_ { //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct num_elems; template<> @@ -79,6 +97,21 @@ struct num_elems { static constexpr int value = 8; }; +#ifdef ENABLE_BF16 +template<> +struct num_elems<__nv_bfloat162> { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -144,6 +177,44 @@ inline __device__ float4 add(float4 a, float4 b) //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ uint16_t add(uint16_t a, uint16_t b) { uint16_t c; @@ -236,6 +307,26 @@ inline __device__ float2 half2_to_float2(uint32_t v) //////////////////////////////////////////////////////////////////////////////////////////////////// +inline __device__ float bfloat16_to_float(__nv_bfloat16 h) +{ + return __bfloat162float(h); + // float f; + // asm volatile("cvt.f32.bf16 %0, %1;\n" : "=f"(f) : "h"(h)); + // return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline 
__device__ float2 bfloat162_to_float2(__nv_bfloat162 v) +{ + return cuda_cast(v); + // __nv_bfloat16 lo, hi; + // asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + // return make_float2(bfloat16_to_float(lo), bfloat16_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float add(float a, uint16_t b) { return a + half_to_float(b); @@ -243,6 +334,15 @@ inline __device__ float add(float a, uint16_t b) //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ float add(float a, __nv_bfloat16 b) +{ + return a + __bfloat162float(b); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float2 add(uint32_t a, float2 fb) { float2 fa = half2_to_float2(a); @@ -367,6 +467,37 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) +{ + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; @@ -498,6 +629,134 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) return fd; } +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(bf162bf162(a), b, c); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) +{ + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) +{ + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) +{ + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) +{ + return fma(bf162bf162(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} +#endif // ENABLE_BF16 //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -746,6 +1005,171 @@ inline __device__ float sum(float v) //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmul(a, b); +#else + return bf16hmul(a, b); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t 
c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ + float fa = (float)a; + float fb = (float)b; + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, float b) +{ + return __bfloat162float(a) * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) +{ + 
__nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float sum(float2 v) { return v.x + v.y; @@ -758,6 +1182,31 @@ inline __device__ float sum(float4 v) return v.x + v.y + v.z + v.w; } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float sum(__nv_bfloat162 v) +{ + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_4_t v) +{ + return sum(v.x) + sum(v.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_8_t v) +{ + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float sum(uint16_t v) { return half_to_float(v); @@ -1048,4 +1497,85 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r k.w = rotary_embedding_transform(k.w, coef3); } +#ifdef ENABLE_BF16 +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, float base, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void +apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, float base, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, float base, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void +apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, float base, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, float base, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = 
rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void +apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, float base, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // ENABLE_BF16 } // namespace mmha diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h index 5a1300ff2d..96a43de1e5 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h +++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h @@ -3,6 +3,7 @@ #pragma once #include "src/turbomind/kernels/gemm_s_f16/common.h" +#include "src/turbomind/utils/cuda_bf16_wrapper.h" #include #include @@ -487,4 +488,34 @@ struct ConvertKvCache { } }; +#ifdef ENABLE_BF16 +template<> +struct ConvertKvCache { + + float scale_; + float zero_; + + __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) + { + zero_ = zero_ - 32896.f * scale_; + } + + template + inline __device__ auto operator()(const Array& vi) const -> Array<__nv_bfloat16, N> + { + Array<__nv_bfloat16, N> vo; + PRAGMA_UNROLL + for (int i = 0; i < N; i += 4) { + auto& vec = (Array<__nv_bfloat16, 4>&)vo[i]; + auto tmp = fast_i2f_f32_s8((const Array&)vi[i]); + PRAGMA_UNROLL + for (int j = 0; j < 4; ++j) { + vec[j] = __nv_bfloat16(tmp[j] * scale_ + zero_); + // vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_)); + } + } + return vo; + } +}; +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu index 02cc827694..bb7c54e7f0 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu +++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu @@ -107,9 +107,26 @@ void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& invokeDecoderMultiheadAttention(params); } } + +#ifdef ENABLE_BF16 + if constexpr (std::is_same_v) { + int group_size = params.num_heads / params.num_kv_heads; + if (group_size % 4 == 0) { + invokeDecoderMultiheadAttention(params); + } + else if (group_size % 2 == 0) { + invokeDecoderMultiheadAttention(params); + } + else { + invokeDecoderMultiheadAttention(params); + } + } +#endif // ENABLE_BF16 } template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& params); template void DispatchDecoderMultiheadAttention(const 
DecoderMultiHeadAttentionParams& params); - +#ifdef ENABLE_BF16 +template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<__nv_bfloat16>& params); +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu index d9a46c40a7..4fec0eadb6 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu +++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu @@ -477,5 +477,21 @@ template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs, int quant_policy, const float* kv_params, cudaStream_t st); - +#ifdef ENABLE_BF16 +template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs, + const void** src_v_block_ptrs, + __nv_bfloat16** dst_k_ptrs, + __nv_bfloat16** dst_v_ptrs, + const int* src_cu_block_cnts, + const int* seq_lens, + int src_offset, + int src_block_len, + int dst_block_len, + int head_num, + int head_dim, + int batch_size, + int quant_policy, + const float* kv_params, + cudaStream_t st); +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 90f303a8bf..3ead17bc82 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1555,5 +1555,8 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) template class LlamaBatch; template class LlamaBatch; +#ifdef ENABLE_BF16 +template class LlamaBatch<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index ab3eb783c4..34c0abf86d 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -377,5 +377,8 @@ TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) template struct LlamaDecoderLayerWeight; template struct LlamaDecoderLayerWeight; +#ifdef ENABLE_BF16 +template struct LlamaDecoderLayerWeight<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 410667f1bd..369f26c736 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -30,6 +30,7 @@ enum class WeightType : int kFP32, kFP16, kFP8, // not supported yet + kBF16, kINT8, kINT4 }; @@ -43,6 +44,8 @@ inline size_t getBitSize(WeightType type) return 16; case WeightType::kFP8: return 8; + case WeightType::kBF16: + return 16; case WeightType::kINT8: return 8; case WeightType::kINT4: diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc index 0d78dc4e80..42575af665 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.cc +++ b/src/turbomind/models/llama/LlamaFfnLayer.cc @@ -125,5 +125,8 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, template class LlamaFfnLayer; template class LlamaFfnLayer; +#ifdef ENABLE_BF16 +template class LlamaFfnLayer<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaLinear.h b/src/turbomind/models/llama/LlamaLinear.h index 0e783df33d..a3717b2a90 100644 --- a/src/turbomind/models/llama/LlamaLinear.h +++ b/src/turbomind/models/llama/LlamaLinear.h @@ -31,6 +31,7 @@ class LlamaLinear { switch (weight.type) { case WeightType::kFP16: case WeightType::kFP32: + case WeightType::kBF16: 
forwardFp(output_data, input_data, batch_size, weight, type); break; case WeightType::kINT4: diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 12a3bc3cf5..c54f6d6d89 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -478,5 +478,8 @@ void LlamaV2::forward(std::unordered_map* outputs, template class LlamaV2; template class LlamaV2; +#ifdef ENABLE_BF16 +template class LlamaV2<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index e270d3ba5c..6e62eaf420 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -90,6 +90,9 @@ template void LlamaWeight::loadModel(std::string dir_path) { FtCudaDataType model_file_type = FtCudaDataType::FP16; + if(weight_type_ == WeightType::kBF16){ + model_file_type = FtCudaDataType::BF16; + } dir_path += '/'; loadWeightFromBin((T*)pre_decoder_embedding_table, @@ -140,5 +143,8 @@ TensorMap LlamaWeight::getParams() template struct LlamaWeight; template struct LlamaWeight; +#ifdef ENABLE_BF16 +template struct LlamaWeight<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/flash_attention2/CMakeLists.txt b/src/turbomind/models/llama/flash_attention2/CMakeLists.txt index 1a1fe37eaa..d41c391e9d 100644 --- a/src/turbomind/models/llama/flash_attention2/CMakeLists.txt +++ b/src/turbomind/models/llama/flash_attention2/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(${PROJECT_NAME} STATIC # flash_fwd_hdim32_fp16_sm80.cu # flash_fwd_hdim64_fp16_sm80.cu flash_fwd_hdim128_fp16_sm80.cu + flash_fwd_hdim128_bf16_sm80.cu # flash_fwd_hdim256_fp16_sm80.cu ) target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR} / include) diff --git a/src/turbomind/models/llama/flash_attention2/flash_api.cpp b/src/turbomind/models/llama/flash_attention2/flash_api.cpp index 55bc92c1ff..3fda17428c 100644 --- a/src/turbomind/models/llama/flash_attention2/flash_api.cpp +++ b/src/turbomind/models/llama/flash_attention2/flash_api.cpp @@ -12,7 +12,7 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { - FP16_SWITCH(true, + FP16_SWITCH(!params.is_bf16, [&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_(params, stream); }); }); } @@ -126,7 +126,11 @@ class FlashAttentionOpImpl::impl { fwd_params.blockmask = reinterpret_cast(params.mask); - fwd_params.is_bf16 = false; +#ifdef ENABLE_BF16 + fwd_params.is_bf16 = std::is_same::value; +#else + fwd_params.is_bf16 = false; +#endif fwd_params.is_causal = true; fwd_params.q_enable_seqlen = params.layout_q.use_seqlens; @@ -163,5 +167,8 @@ void FlashAttentionOpImpl::operator()(Params& params, cudaStrea template class FlashAttentionOpImpl; template class FlashAttentionOpImpl; +#ifdef ENABLE_BF16 +template class FlashAttentionOpImpl<__nv_bfloat16, FMHA_VERSION>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim128_bf16_sm80.cu b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000..5b63fe086f --- /dev/null +++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+
+#include "flash_fwd_launch_template.h"
+
+#ifdef ENABLE_BF16
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
+{
+    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+}
+#endif
diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim256_bf16_sm80.cu b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim256_bf16_sm80.cu
new file mode 100644
index 0000000000..00dbdc700b
--- /dev/null
+++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+#ifdef ENABLE_BF16
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
+{
+    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}
+#endif
diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim32_bf16_sm80.cu b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim32_bf16_sm80.cu
new file mode 100644
index 0000000000..f3b2df8cd0
--- /dev/null
+++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+#ifdef ENABLE_BF16
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
+{
+    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+}
+#endif
diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim64_bf16_sm80.cu b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim64_bf16_sm80.cu
new file mode 100644
index 0000000000..638d86cb71
--- /dev/null
+++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+ +#include "flash_fwd_launch_template.h" + +#ifdef ENABLE_BF16 +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim64(params, stream); +} +#endif diff --git a/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu b/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu index 4fae69bd08..12d70f6675 100644 --- a/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu +++ b/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu @@ -13,6 +13,27 @@ namespace turbomind { +template +struct ToCutlassType_ { +}; + +template<> +struct ToCutlassType_ { + using Type = float; +}; + +template<> +struct ToCutlassType_ { + using Type = cutlass::half_t; +}; + +#ifdef ENABLE_BF16 +template<> +struct ToCutlassType_<__nv_bfloat16> { + using Type = cutlass::bfloat16_t; +}; +#endif + template< // dtype of Q/K/V/M typename Element_, @@ -655,8 +676,7 @@ void invokeFlashAttention_impl(int batc auto layout_o = attention_params.layout_o; auto group_size = attention_params.group_size; - using scalar_t = - typename std::conditional_t::type>::value, cutlass::half_t, T>; + using scalar_t = typename ToCutlassType_::Type; const float qk_scale = static_cast(1.f / sqrtf(size_per_head * 1.f)); @@ -742,8 +762,7 @@ void invokeFlashAttention_impl(int batc template bool get_needs_accum_buffer() { - using scalar_t = - typename std::conditional_t::type>::value, cutlass::half_t, T>; + using scalar_t = typename ToCutlassType_::Type; #define GET_NEED_ACCUM_BUFFER(sm) \ ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, false)::kNeedsOutputAccumulatorBuffer @@ -774,8 +793,7 @@ void invoke_attention_impl(bool single_v typename FlashAttentionOpImpl::Params& params, cudaStream_t st) { - using scalar_t = - typename std::conditional_t::type>::value, cutlass::half_t, T>; + using scalar_t = typename ToCutlassType_::Type; #define INVOKE_ATTEN_IMPL(sm, single_value) \ { \ @@ -836,9 +854,8 @@ class FlashAttentionOpImpl::impl { private: static constexpr int kQueriesPerBlock = 32; static constexpr int kKeysPerBlock = 128; - using scalar_t = - typename std::conditional_t::type>::value, cutlass::half_t, T>; - using Params = typename FlashAttentionOpImpl::Params; + using scalar_t = typename ToCutlassType_::Type; + using Params = typename FlashAttentionOpImpl::Params; int batch_size_; int head_num_; @@ -909,5 +926,8 @@ void FlashAttentionOpImpl::operator()(Params& params, cudaStream_t st) con template class FlashAttentionOpImpl; template class FlashAttentionOpImpl; +#ifdef ENABLE_BF16 +template class FlashAttentionOpImpl<__nv_bfloat16, 1>; +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_decoder_kernels.cu b/src/turbomind/models/llama/llama_decoder_kernels.cu index 1fe1281af7..6bdfa2c5e6 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.cu +++ b/src/turbomind/models/llama/llama_decoder_kernels.cu @@ -2,6 +2,7 @@ #include "src/turbomind/macro.h" #include "src/turbomind/models/llama/llama_decoder_kernels.h" +#include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_utils.h" #include #include @@ -83,6 +84,32 @@ struct res_norm_ops_t { } }; +#ifdef ENABLE_BF16 +template<> +struct res_norm_ops_t<__nv_bfloat16> { + __device__ float2 cast(const uint& x) const + { + return cuda_cast(reinterpret_cast(x)); + } + __device__ uint cast(const float2& x) const + { + auto y = cuda_cast<__nv_bfloat162, 
float2>(x); + return reinterpret_cast(y); + } + __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const + { + float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y}; + accum += c.x * c.x + c.y * c.y; + return c; + } + __device__ float2 norm(const float2& a, const float2& s, float factor) const + { + return {a.x * s.x * factor, a.y * s.y * factor}; + } +}; + +#endif + template __device__ T blockReduceSum(const cg::thread_block& block, T value) { @@ -164,5 +191,8 @@ void invokeFusedAddBiasResidualRMSNorm( template void invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t); template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t); - +#ifdef ENABLE_BF16 +template void invokeFusedAddBiasResidualRMSNorm( + __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu index ff628dcced..1ba17f9f90 100644 --- a/src/turbomind/models/llama/llama_kernels.cu +++ b/src/turbomind/models/llama/llama_kernels.cu @@ -90,6 +90,10 @@ void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t); template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t); +#ifdef ENABLE_BF16 +template void +invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +#endif // #ifdef ENABLE_BF16 @@ -208,6 +212,23 @@ void invokeCreateCausalMasks( template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t); template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t); +#ifdef ENABLE_BF16 +template<> +__global__ void createCausalMasks<__nv_bfloat16>( + __nv_bfloat16* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len) +{ + const auto q_len = q_lens[blockIdx.x]; + const auto k_len = k_lens[blockIdx.x]; + mask += blockIdx.x * max_q_len * max_k_len; + for (int i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) { + const int q = i / max_k_len; // [0, max_q_len) + const int k = i % max_k_len; // [0, max_k_len) + bool is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len); + mask[i] = static_cast<__nv_bfloat16>(float(is_valid)); + } +} +template void invokeCreateCausalMasks(__nv_bfloat16* mask, const int*, const int*, int, int, int, cudaStream_t); +#endif template struct ExtendKvCache { @@ -377,6 +398,24 @@ template void invokeExtendKVCache(void** k_dst_ptrs, int quant, const float* kv_scale, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeExtendKVCache(void** k_dst_ptrs, + void** v_dst_ptrs, + const __nv_bfloat16* k_src, + const __nv_bfloat16* v_src, + const int* cu_block_counts, + const int* query_length, + const int* history_length, + int batch_size, + int block_length, + size_t dst_layer_offset, + int max_q_len, + int head_dim, + int head_num, + int quant, + const float* kv_scale, + cudaStream_t stream); +#endif template struct TransposeKvCache { @@ -527,6 +566,23 @@ template void invokeTransposeKVCache(half*, cudaStream_t stream, int, const float*); +#ifdef ENABLE_BF16 +template void invokeTransposeKVCache(__nv_bfloat16*, + __nv_bfloat16*, + 
const __nv_bfloat16**, + const __nv_bfloat16**, + size_t, + int, + const int*, + int, + int, + int, + int, + int, + cudaStream_t stream, + int, + const float*); +#endif __global__ void gatherOutput(int* output_ids, const int* ids, @@ -776,6 +832,9 @@ void invokeGetFeatureOfLastToken( template void invokeGetFeatureOfLastToken(half*, const half*, const int*, int, int, cudaStream_t); template void invokeGetFeatureOfLastToken(float*, const float*, const int*, int, int, cudaStream_t); +#ifdef ENABLE_BF16 +template void invokeGetFeatureOfLastToken(__nv_bfloat16*, const __nv_bfloat16*, const int*, int, int, cudaStream_t); +#endif // ENABLE_BF16 template struct BatchedCopyParam { @@ -866,7 +925,7 @@ FlashAttentionOp::FlashAttentionOp(int batch_size, int head_num, int key_len, #ifdef _MSC_VER op_version_ = 1; #else - op_version_ = std::is_same::type>::value ? 2 : 1; + op_version_ = std::is_same::type>::value ? 1 : 2; if (op_version_ == 2 && getSMVersion() < 80) { op_version_ = 1; } @@ -903,5 +962,8 @@ void FlashAttentionOp::operator()(Params& params, cudaStream_t st) const template class FlashAttentionOp; template class FlashAttentionOp; +#ifdef ENABLE_BF16 +template class FlashAttentionOp<__nv_bfloat16>; +#endif } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index d9ae5d0be6..aeb8c5db48 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -626,5 +626,8 @@ void UnifiedAttentionLayer::unfusedMultiHeadAttention(T* output, template class UnifiedAttentionLayer; template class UnifiedAttentionLayer; +#ifdef ENABLE_BF16 +template class UnifiedAttentionLayer<__nv_bfloat16>; +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 20974eeea9..311fc6b764 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -253,5 +253,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con template class UnifiedDecoder; template class UnifiedDecoder; +#ifdef ENABLE_BF16 +template class UnifiedDecoder<__nv_bfloat16>; +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 46e8443a86..85a3c8af83 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace py = pybind11; namespace ft = turbomind; @@ -23,12 +24,6 @@ using TensorMap = std::unordered_map; PYBIND11_MAKE_OPAQUE(TensorMap); static const char kDlTensorCapsuleName[] = "dltensor"; -template -std::shared_ptr make_shared_nodel(T data) -{ - return std::shared_ptr(&data, [](T*) {}); -} - DLDevice getDLDevice(triton::Tensor& tensor) { int device_id = 0; @@ -46,6 +41,7 @@ DLDevice getDLDevice(triton::Tensor& tensor) break; case triton::MEMORY_CPU_PINNED: device.device_type = DLDeviceType::kDLCUDAHost; + break; case triton::MEMORY_GPU: device.device_type = DLDeviceType::kDLCUDA; break; @@ -132,12 +128,11 @@ std::unique_ptr TritonTensorToDLManagedTensor(triton::Tensor& t triton::MemoryType getMemoryType(DLDevice device) { switch (device.device_type) { - case DLDeviceType::kDLCPU: - return triton::MemoryType::MEMORY_CPU; case DLDeviceType::kDLCUDAHost: return triton::MemoryType::MEMORY_CPU_PINNED; case DLDeviceType::kDLCUDA: return 
triton::MemoryType::MEMORY_GPU; + case DLDeviceType::kDLCPU: default: return triton::MemoryType::MEMORY_CPU; } @@ -289,17 +284,21 @@ PYBIND11_MODULE(_turbomind, m) DLManagedTensor* dlmt = static_cast(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName)); auto src = DLManagedTensorToTritonTensor(dlmt); - if (self->type == triton::TYPE_FP16 || self->type == triton::TYPE_FP32 - || self->type == triton::TYPE_INT32) { - auto num_element = - std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies()); - auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8; - ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]); - cudaMemcpy( - const_cast(self->data), const_cast(src->data), num_bytes, cudaMemcpyDefault); - } - else { - ft::FT_CHECK(0); + switch (self->type) { + case triton::TYPE_FP16: + case triton::TYPE_FP32: + case triton::TYPE_INT32: + case triton::TYPE_BF16: { + auto num_element = + std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies()); + auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8; + ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]); + cudaMemcpy( + const_cast(self->data), const_cast(src->data), num_bytes, cudaMemcpyDefault); + break; + } + default: + ft::FT_CHECK(0); } }, "tensor"_a) @@ -380,6 +379,16 @@ PYBIND11_MODULE(_turbomind, m) model->setFfiLock(gil_control); return model; } + else if (data_type == "bf16") { +#ifdef ENABLE_BF16 + auto model = std::make_shared>( + tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); + model->setFfiLock(gil_control); + return model; +#else + throw std::runtime_error("Error: turbomind has not been built with bf16 support."); +#endif + } else { auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 33711b502d..6127fdf2db 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -47,6 +47,18 @@ std::shared_ptr AbstractTransformerModel::createLlamaM reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), model_dir); } + else if (data_type == "bf16") { +#ifdef ENABLE_BF16 + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir); +#else + TM_LOG_ERROR("[ERROR] Turbomind is not built with ENABLE_BF16"); + ft::FT_CHECK(false); +#endif + } else { return std::make_shared>( reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), @@ -205,6 +217,9 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, if (weight_type_str == "fp16") { weight_type_ = ft::WeightType::kFP16; } + else if (weight_type_str == "bf16") { + weight_type_ = ft::WeightType::kBF16; + } else if (weight_type_str == "fp32") { weight_type_ = ft::WeightType::kFP32; } @@ -260,6 +275,11 @@ std::unique_ptr> LlamaTritonModel::createSh else if (std::is_same::value) { cublas_wrapper->setFP32GemmConfig(); } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif ft::NcclParam tensor_para = nccl_params.first[comms_rank]; ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; @@ -449,3 +469,6 @@ int 
LlamaTritonModel<T>::getPipelineParaSize()
 
 template struct LlamaTritonModel<float>;
 template struct LlamaTritonModel<half>;
+#ifdef ENABLE_BF16
+template struct LlamaTritonModel<__nv_bfloat16>;
+#endif
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc
index b4666bd1e7..5513681399 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc
@@ -244,3 +244,6 @@ void LlamaTritonModelInstance<T>::freeBuffer()
 
 template struct LlamaTritonModelInstance<float>;
 template struct LlamaTritonModelInstance<half>;
+#ifdef ENABLE_BF16
+template struct LlamaTritonModelInstance<__nv_bfloat16>;
+#endif
diff --git a/src/turbomind/utils/cuda_type_utils.cuh b/src/turbomind/utils/cuda_type_utils.cuh
index b1b9c40e87..f7f7b95273 100644
--- a/src/turbomind/utils/cuda_type_utils.cuh
+++ b/src/turbomind/utils/cuda_type_utils.cuh
@@ -507,12 +507,12 @@ __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
 template<>
 __device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
 {
-    return fabs(val);
+    return fabs(cuda_cast<float>(val));
 }
 
 template<>
 __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
 {
-    return make_bfloat162(fabs(val.x), fabs(val.y));
+    return make_bfloat162(fabs(cuda_cast<float>(val.x)), fabs(cuda_cast<float>(val.y)));
 }
 #endif
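
Usage note (editorial, not part of the patch): a minimal sketch of how the two converter helpers added above are meant to be combined, assuming the patched lmdeploy.turbomind.deploy modules and a CUDA-capable machine; the model name, path, and the TurbomindModelConfig construction are illustrative placeholders, not values taken from the diff.

# Hedged usage sketch for update_output_format / update_config_weight_type,
# mirroring the new `else:` branch in converter.main().
import torch

from lmdeploy.turbomind.deploy.converter import (update_config_weight_type,
                                                 update_output_format)
from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig

# Assumed way to obtain an empty config object for illustration.
cfg = TurbomindModelConfig.from_dict({}, allow_none=True)

# 'qwen-7b' -> prefix 'qwen' -> 'bf16' in MODEL_NAME_MAP; the helper downgrades
# the result to 'fp16' when torch.cuda.is_bf16_supported() is False. The path
# is not read in this branch, so the placeholder is never touched.
output_format = update_output_format(model_name='qwen-7b',
                                     model_format='qwen',
                                     model_path='/path/to/hf/model',
                                     output_format='fp16')
update_config_weight_type(output_format, cfg)
print(output_format, cfg.weight_type)  # 'bf16', 'bf16' on Ampere or newer GPUs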