diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 6690a05634..ed713fe5d8 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -894,6 +894,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
         session_len_ = max_session_len;
     }
 
+    FT_CHECK(max_context_token_num_ >= session_len_);
+
     for (auto& s : states_) {
         s.requests.resize(max_batch_size_);
         s.sequences.resize(max_batch_size_);
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index 20974eeea9..358f5c04f6 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -15,8 +15,14 @@ void UnifiedDecoder<T>::allocateBuffer(size_t num_token, size_t pf_batch_size, s
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
     if (pf_batch_size) {
-        attention_mask_ =
-            (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        if (need_causal_mask_) {
+            attention_mask_ = (T*)allocator_->reMalloc(
+                attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        }
+        else {
+            // just to avoid nullptr
+            attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T), false);
+        }
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false);
         cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false);
@@ -162,14 +168,16 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
 
         FT_CHECK(tmp_token_num == token_num - dc_batch_size);
 
-        invokeCreateCausalMasks(attention_mask_,
-                                input_length + pf_offset,
-                                context_length + pf_offset,
-                                pf_max_q_len,
-                                pf_max_k_len,
-                                pf_batch_size,
-                                stream_);
-        sync_check_cuda_error();
+        if (need_causal_mask_) {
+            invokeCreateCausalMasks(attention_mask_,
+                                    input_length + pf_offset,
+                                    context_length + pf_offset,
+                                    pf_max_q_len,
+                                    pf_max_k_len,
+                                    pf_batch_size,
+                                    stream_);
+            sync_check_cuda_error();
+        }
     }
 
     /////////////////////////////////////////////
diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h
index daac2b4df6..533976f947 100644
--- a/src/turbomind/models/llama/unified_decoder.h
+++ b/src/turbomind/models/llama/unified_decoder.h
@@ -5,6 +5,7 @@
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/unified_attention_layer.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
 
 namespace turbomind {
@@ -46,6 +47,8 @@ class UnifiedDecoder {
 
     const DataType dtype_;
 
+    bool need_causal_mask_{false};
+
     using WeightType = LlamaDecoderLayerWeight<T>;
 
     void forwardSelfAttn(T* attn_io,
@@ -88,6 +91,14 @@ class UnifiedDecoder {
         tensor_para_(tensor_para),
         dtype_(getTensorType<T>())
     {
+#ifdef _MSC_VER
+        // Both unfused MHA and flash attention 1 need causal mask
+        need_causal_mask_ = true;
+#endif
+        // attention mask is not used for FA-1 (which requires sm80+ and half/bf16 data type)
+        if (!use_fmha || (getSMVersion() < 80 || sizeof(T) != 2)) {
+            need_causal_mask_ = true;
+        }
         initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
     }