From e7886b43e6097514700c124e262246ec39205004 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Wed, 6 Nov 2024 15:27:46 +0800 Subject: [PATCH] support turbomind head_dim 64 (#2715) * support head_dim 64 * fix unit-test * fix wrong dispatch * fix comments * fix comments --- .../turbomind/deploy/source_model/internvl.py | 5 +++ .../turbomind/deploy/source_model/llama.py | 2 + lmdeploy/turbomind/supported_models.py | 12 ++---- .../kernels/attention/CMakeLists.txt | 16 ++++++++ src/turbomind/kernels/attention/attention.cu | 19 ++++++--- .../codegen/attention_sm70_64_f16.cu | 16 ++++++++ .../codegen/attention_sm75_64_f16.cu | 17 ++++++++ .../codegen/attention_sm80_64_bf16.cu | 16 ++++++++ .../codegen/attention_sm80_64_f16.cu | 16 ++++++++ .../codegen/decoding_sm70_64_f16_f16.cu | 16 ++++++++ .../codegen/decoding_sm70_64_f16_u4.cu | 17 ++++++++ .../codegen/decoding_sm70_64_f16_u8.cu | 17 ++++++++ .../codegen/decoding_sm75_64_f16_f16.cu | 14 +++++++ .../codegen/decoding_sm75_64_f16_u4.cu | 14 +++++++ .../codegen/decoding_sm75_64_f16_u8.cu | 14 +++++++ .../codegen/decoding_sm80_64_bf16_bf16.cu | 22 ++++++++++ .../codegen/decoding_sm80_64_bf16_u4.cu | 14 +++++++ .../codegen/decoding_sm80_64_bf16_u8.cu | 14 +++++++ .../codegen/decoding_sm80_64_f16_f16.cu | 18 +++++++++ .../codegen/decoding_sm80_64_f16_u4.cu | 14 +++++++ .../codegen/decoding_sm80_64_f16_u8.cu | 14 +++++++ src/turbomind/kernels/attention/decoding.cu | 33 +++++++++------ src/turbomind/kernels/attention/impl_16816.h | 17 ++++---- src/turbomind/kernels/attention/impl_1688.h | 12 ++++-- src/turbomind/kernels/attention/impl_81616.h | 4 +- .../kernels/attention/kv_cache_utils_v2.cu | 40 ++++++++++++++----- src/turbomind/kernels/attention/reduce.cu | 39 ++++++++---------- tests/test_lmdeploy/test_auto_backend.py | 2 +- 28 files changed, 383 insertions(+), 71 deletions(-) create mode 100644 src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu create mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu diff --git a/lmdeploy/turbomind/deploy/source_model/internvl.py b/lmdeploy/turbomind/deploy/source_model/internvl.py index 51082fb3a1..bb660a59b2 100644 --- a/lmdeploy/turbomind/deploy/source_model/internvl.py +++ b/lmdeploy/turbomind/deploy/source_model/internvl.py @@ -80,8 +80,13 @@ def model_info(self): scaling_factor = model_arg['rope_scaling'].get('factor', '') if scaling_type == 'dynamic': use_dynamic_ntk = 1 + attn_bias = 1 if model_arg['architectures'][ + 0] == 'Qwen2ForCausalLM' else 0 return dict(num_layer=num_layer, + size_per_head=hidden_units // attn_head_num, + rotary_embedding=hidden_units // attn_head_num, + attn_bias=attn_bias, norm_eps=norm_eps, hidden_units=hidden_units, inter_size=inter_size, diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py index 8e19fa8d87..a8aa51b144 100644 --- a/lmdeploy/turbomind/deploy/source_model/llama.py +++ b/lmdeploy/turbomind/deploy/source_model/llama.py @@ -189,6 +189,8 @@ def model_info(self): beta_slow = rope_scaling.get('beta_slow', 1.0) return dict( + size_per_head=hidden_units // attn_head_num, + rotary_embedding=hidden_units // attn_head_num, num_layer=num_layer, norm_eps=norm_eps, head_num=attn_head_num, diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 8a1f5e7315..979ed0c547 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -65,11 +65,10 @@ def is_supported(model_path: str): """ # noqa: E501 import os - def _is_head_dim_128(cfg): + def _is_head_dim_supported(cfg): num_attn_head = cfg.num_attention_heads hidden_size = cfg.hidden_size - # turbomind support head_dim=128 - return (hidden_size // num_attn_head) == 128 + return (hidden_size // num_attn_head) in [128, 64] support_by_turbomind = False triton_model_path = os.path.join(model_path, 'triton_models') @@ -87,9 +86,7 @@ def _is_head_dim_128(cfg): # baichuan-13B, baichuan2-13B not supported by turbomind support_by_turbomind = False elif arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']: - # the head_dim of qwen2 0.5b and llama3.2-1b is 64, which - # hasn't been supported by turbomind yet - support_by_turbomind = _is_head_dim_128(cfg) + support_by_turbomind = _is_head_dim_supported(cfg) elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'): # chatglm1/2/3 is not working yet support_by_turbomind = cfg.num_layers == 40 @@ -97,7 +94,6 @@ def _is_head_dim_128(cfg): # glm-4v-9b not supported support_by_turbomind = False elif arch == 'InternVLChatModel': - # internvl2-4b,internlm2-1b are not working yet - support_by_turbomind = _is_head_dim_128(cfg.llm_config) + support_by_turbomind = _is_head_dim_supported(cfg.llm_config) return support_by_turbomind diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index 4ca63f5db6..af9d47e0e6 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -22,6 +22,22 @@ add_library(attention STATIC codegen/decoding_sm80_128_f16_f16.cu codegen/decoding_sm80_128_f16_u4.cu codegen/decoding_sm80_128_f16_u8.cu + codegen/attention_sm70_64_f16.cu + codegen/attention_sm75_64_f16.cu + codegen/attention_sm80_64_bf16.cu + codegen/attention_sm80_64_f16.cu + codegen/decoding_sm70_64_f16_f16.cu + codegen/decoding_sm70_64_f16_u4.cu + codegen/decoding_sm70_64_f16_u8.cu + codegen/decoding_sm75_64_f16_f16.cu + codegen/decoding_sm75_64_f16_u4.cu + codegen/decoding_sm75_64_f16_u8.cu + codegen/decoding_sm80_64_bf16_bf16.cu + codegen/decoding_sm80_64_bf16_u4.cu + codegen/decoding_sm80_64_bf16_u8.cu + codegen/decoding_sm80_64_f16_f16.cu + codegen/decoding_sm80_64_f16_u4.cu + codegen/decoding_sm80_64_f16_u8.cu ) set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/attention/attention.cu b/src/turbomind/kernels/attention/attention.cu index ffbad56b46..3f557234bc 100644 --- a/src/turbomind/kernels/attention/attention.cu +++ b/src/turbomind/kernels/attention/attention.cu @@ -14,20 +14,19 @@ template void dispatchAttention(const AttentionParams& params) { using namespace attention; - if (params.size_per_head == 128) { - + auto dispatch = [&](const auto dim) { + constexpr int kHeadDim = dim; if (params.arch >= 80) { - using Config = AttentionConfig; + using Config = AttentionConfig; return invokeAttention(params); } - if constexpr (!std::is_same_v) { if (params.arch == 75) { - return invokeAttention::Kernel>( + return invokeAttention::Kernel>( params); } else if (params.arch >= 70) { - return invokeAttention::Kernel>( + return invokeAttention::Kernel>( params); } } @@ -38,6 +37,14 @@ void dispatchAttention(const AttentionParams& params) params.arch); } } + FT_CHECK(0); + }; + + if (params.size_per_head == 64) { + return dispatch(std::integral_constant{}); + } + else if (params.size_per_head == 128) { + return dispatch(std::integral_constant{}); } FT_CHECK(0); } diff --git a/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu new file mode 100644 index 0000000000..72b219432c --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_config.h" +#include "../attention_template.h" + +namespace turbomind { + +using namespace attention; + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu new file mode 100644 index 0000000000..cef945015a --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_config.h" +#include "../attention_template.h" + +namespace turbomind { + +using namespace attention; + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +// ! register spill +// template void invokeAttention::Kernel>( +// const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu new file mode 100644 index 0000000000..cc6e54c14b --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_config.h" +#include "../attention_template.h" + +namespace turbomind { + +using namespace attention; + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu new file mode 100644 index 0000000000..26e3f54b29 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_config.h" +#include "../attention_template.h" + +namespace turbomind { + +using namespace attention; + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu new file mode 100644 index 0000000000..12558aeae6 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu new file mode 100644 index 0000000000..25b49f9590 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_params.h" +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu new file mode 100644 index 0000000000..824cd5b02e --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_params.h" +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu new file mode 100644 index 0000000000..456e6e18d7 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu new file mode 100644 index 0000000000..171e59f5f1 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu new file mode 100644 index 0000000000..1d6d40ed3a --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu new file mode 100644 index 0000000000..b657034c4c --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu @@ -0,0 +1,22 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool +invokeDecoding>(const AttentionParams& params); + +template bool +invokeDecoding>(const AttentionParams& params); + +template bool +invokeDecoding>(const AttentionParams& params); + +template bool +invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu new file mode 100644 index 0000000000..a5c0b34b7f --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams&); + +template bool invokeDecoding>(const AttentionParams&); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu new file mode 100644 index 0000000000..a7dd3050b1 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams&); + +template bool invokeDecoding>(const AttentionParams&); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu new file mode 100644 index 0000000000..e73be11e62 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu @@ -0,0 +1,18 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu new file mode 100644 index 0000000000..c7c560e98d --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams&); + +template bool invokeDecoding>(const AttentionParams&); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu new file mode 100644 index 0000000000..06f6ce5600 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool invokeDecoding>(const AttentionParams&); + +template bool invokeDecoding>(const AttentionParams&); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/decoding.cu b/src/turbomind/kernels/attention/decoding.cu index 2b9328a681..1b04b7d4eb 100644 --- a/src/turbomind/kernels/attention/decoding.cu +++ b/src/turbomind/kernels/attention/decoding.cu @@ -29,8 +29,6 @@ constexpr auto get_kv_type(std::integral_constant) template void dispatchDecoding(const AttentionParams& params) { - static constexpr std::integral_constant kHeadDim{}; - const bool is_kv_int8 = params.quant_policy & QuantPolicy::kCacheKVInt8; const bool is_kv_int4 = params.quant_policy & QuantPolicy::kCacheKVInt4; const int query_group_sz = params.num_heads / params.num_kv_heads; @@ -39,9 +37,10 @@ void dispatchDecoding(const AttentionParams& params) /// TODO: we need better Qh dispatching, when #waves < 1, smaller Qh may outperform larger Qh due to better // concurrency - auto dispatch_h = [&](auto arch, auto kv) -> bool { - using Arch = decltype(arch); - using Tkv = decltype(kv); + auto dispatch_h = [&](auto arch, auto kv, const auto dim) -> bool { + using Arch = decltype(arch); + using Tkv = decltype(kv); + constexpr int kHeadDim = dim; if (0) {} else if (query_group_sz > 8) { return invokeDecoding>(params); @@ -73,31 +72,41 @@ void dispatchDecoding(const AttentionParams& params) return false; }; - auto dispatch_kv = [&](auto arch) -> bool { + auto dispatch_kv = [&](auto arch, const auto dim) -> bool { FT_CHECK(!(is_kv_int4 && is_kv_int8)); if (is_kv_int4) { - return dispatch_h(arch, uint4_t{}); + return dispatch_h(arch, uint4_t{}, dim); } else if (is_kv_int8) { - return dispatch_h(arch, uint8_t{}); + return dispatch_h(arch, uint8_t{}, dim); } else { - return dispatch_h(arch, T{}); + return dispatch_h(arch, T{}, dim); + } + return false; + }; + + auto dispatch_head_dim = [&](auto arch) { + if (params.size_per_head == 128) { + return dispatch_kv(arch, std::integral_constant{}); + } + else if (params.size_per_head == 64) { + return dispatch_kv(arch, std::integral_constant{}); } return false; }; auto dispatch = [&]() { if (params.arch >= 80) { - return dispatch_kv(arch::Sm80{}); + return dispatch_head_dim(arch::Sm80{}); } if constexpr (!std::is_same_v) { if (params.arch == 75) { - return dispatch_kv(arch::Sm75{}); + return dispatch_head_dim(arch::Sm75{}); } else if (params.arch >= 70) { - return dispatch_kv(arch::Sm70{}); + return dispatch_head_dim(arch::Sm70{}); } } diff --git a/src/turbomind/kernels/attention/impl_16816.h b/src/turbomind/kernels/attention/impl_16816.h index 69e0a6a48c..6e8f37f4d4 100644 --- a/src/turbomind/kernels/attention/impl_16816.h +++ b/src/turbomind/kernels/attention/impl_16816.h @@ -63,14 +63,15 @@ struct Impl>; -#if 0 - using SmemLayoutK = SmemLayoutV2>; - using SmemLayoutV = SmemLayoutV2>; -#else - using SmemLayoutK = SmemLayoutV2>; - using SmemLayoutV = SmemLayoutV2>; -#endif + using SmemLayoutQ = std::conditional_t>, + SmemLayoutV2>>; + using SmemLayoutK = std::conditional_t>, + SmemLayoutV2>>; + using SmemLayoutV = std::conditional_t>, + SmemLayoutV2>>; using SmemLayoutKVp = void; diff --git a/src/turbomind/kernels/attention/impl_1688.h b/src/turbomind/kernels/attention/impl_1688.h index 856ddcd587..a822c58039 100644 --- a/src/turbomind/kernels/attention/impl_1688.h +++ b/src/turbomind/kernels/attention/impl_1688.h @@ -61,9 +61,15 @@ struct Impl[V_K][V_N]; // ((d8, s4), (Sk, Dn), (s2)) // 1 2 8 8 1 - using SmemLayoutQ = SmemLayoutV2>; - using SmemLayoutK = SmemLayoutV2>; // load by (s32,d8) tile - using SmemLayoutV = SmemLayoutV2>; // load by (s8,d32) tile + using SmemLayoutQ = std::conditional_t>, + SmemLayoutV2>>; + using SmemLayoutK = std::conditional_t>, + SmemLayoutV2>>; + using SmemLayoutV = std::conditional_t>, + SmemLayoutV2>>; using SmemLayoutKVp = void; diff --git a/src/turbomind/kernels/attention/impl_81616.h b/src/turbomind/kernels/attention/impl_81616.h index 0c0baa531a..3b90bcdf57 100644 --- a/src/turbomind/kernels/attention/impl_81616.h +++ b/src/turbomind/kernels/attention/impl_81616.h @@ -104,7 +104,9 @@ struct Impl) { - return SmemLayoutV2>{}; + return std::conditional_t>, + SmemLayoutV2>>{}; } using SmemLayoutQ = SmemLayoutV2>; diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 9f28a17b83..20bb00fde8 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -241,10 +241,10 @@ void invokeProcessKV_v2(char** blocks, int block = WARPS * WARP_SIZE; dim3 grid((max_q_len + CTA_S - 1) / CTA_S, head_num, batch_size); - auto invoke = [&](auto tkv) { + auto invoke = [&](auto tkv, const auto dim) { using Tkv = decltype(tkv); - constexpr int kHeadDim = 128; + constexpr int kHeadDim = dim; FT_CHECK(head_dim == kHeadDim); block::Layout block_layout{block::Config{head_num, block_seq_len}}; @@ -276,14 +276,24 @@ void invokeProcessKV_v2(char** blocks, block_layout); }; + auto dispatch = [&](auto tkv) { + if (head_dim == 128) { + return invoke(tkv, std::integral_constant{}); + } + else if (head_dim == 64) { + return invoke(tkv, std::integral_constant{}); + } + FT_CHECK(0); + }; + if (quant_policy & QuantPolicy::kCacheKVInt8) { - invoke(uint8_t{}); + dispatch(uint8_t{}); } else if (quant_policy & QuantPolicy::kCacheKVInt4) { - invoke(uint4_t{}); + dispatch(uint4_t{}); } else { - invoke(T{}); + dispatch(T{}); } } @@ -496,10 +506,10 @@ void invokeFlattenKV_v2(T* k, constexpr int block = kWarpCnt * WARP_SIZE; const dim3 grid((max_seq_len + CTA_S - 1) / CTA_S, head_num, batch_size); - auto invoke = [&](auto tkv) { + auto invoke = [&](auto tkv, const auto dim) { using Tkv = decltype(tkv); - constexpr int kHeadDim = 128; + constexpr int kHeadDim = dim; FT_CHECK(head_dim == kHeadDim); block::Layout block_layout{block::Config{head_num, block_seq_len}}; @@ -528,14 +538,24 @@ void invokeFlattenKV_v2(T* k, block_layout); }; + auto dispatch = [&](auto tkv) { + if (head_dim == 64) { + return invoke(tkv, std::integral_constant{}); + } + else if (head_dim == 128) { + return invoke(tkv, std::integral_constant{}); + } + FT_CHECK(0); + }; + if (quant_policy & QuantPolicy::kCacheKVInt8) { - invoke(uint8_t{}); + dispatch(uint8_t{}); } else if (quant_policy & QuantPolicy::kCacheKVInt4) { - invoke(uint4_t{}); + dispatch(uint4_t{}); } else { - invoke(T{}); + dispatch(T{}); } } diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 44b3dbfdaa..12f6aff38b 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -53,30 +53,25 @@ void invokeReduce(T* out, invoke(std::true_type{}, stride_k); } -template void invokeReduce<128>(half* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream); +#define INSTANTIATE_invokeReduce(dim, type) \ + template void invokeReduce(type * out, \ + float* partial_M, \ + float* partial_L, \ + float* partial_O, \ + const int* split_cnt, \ + int partial_len, \ + int max_split_cnt, \ + int query_num, \ + int head_num, \ + float exp_scale, \ + cudaStream_t stream); + +INSTANTIATE_invokeReduce(128, half); +INSTANTIATE_invokeReduce(64, half); #if ENABLE_BF16 -template void invokeReduce<128>(nv_bfloat16* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream); +INSTANTIATE_invokeReduce(128, nv_bfloat16); +INSTANTIATE_invokeReduce(64, nv_bfloat16) #endif } // namespace turbomind::attention diff --git a/tests/test_lmdeploy/test_auto_backend.py b/tests/test_lmdeploy/test_auto_backend.py index 3dfcac292a..5db727f17f 100644 --- a/tests/test_lmdeploy/test_auto_backend.py +++ b/tests/test_lmdeploy/test_auto_backend.py @@ -38,7 +38,7 @@ def models(self): ('Qwen/Qwen-7B-Chat', True, True), ('Qwen/Qwen-VL-Chat', False, True), ('Qwen/Qwen1.5-4B-Chat', True, True), - ('Qwen/Qwen1.5-0.5B-Chat', True, False), + ('Qwen/Qwen1.5-0.5B-Chat', True, True), ] return models