Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable attention mask when it is not needed #813

Merged
merged 2 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
session_len_ = max_session_len;
}

FT_CHECK(max_context_token_num_ >= session_len_);

for (auto& s : states_) {
s.requests.resize(max_batch_size_);
s.sequences.resize(max_batch_size_);
Expand Down
28 changes: 18 additions & 10 deletions src/turbomind/models/llama/unified_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ void UnifiedDecoder<T>::allocateBuffer(size_t num_token, size_t pf_batch_size, s
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

if (pf_batch_size) {
attention_mask_ =
(T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
if (need_causal_mask_) {
attention_mask_ = (T*)allocator_->reMalloc(
attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
}
else {
// just to avoid nullptr
attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T), false);
}
padding_offset_ =
(int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false);
cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false);
Expand Down Expand Up @@ -162,14 +168,16 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con

FT_CHECK(tmp_token_num == token_num - dc_batch_size);

invokeCreateCausalMasks(attention_mask_,
input_length + pf_offset,
context_length + pf_offset,
pf_max_q_len,
pf_max_k_len,
pf_batch_size,
stream_);
sync_check_cuda_error();
if (need_causal_mask_) {
invokeCreateCausalMasks(attention_mask_,
input_length + pf_offset,
context_length + pf_offset,
pf_max_q_len,
pf_max_k_len,
pf_batch_size,
stream_);
sync_check_cuda_error();
}
}

/////////////////////////////////////////////
Expand Down
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/unified_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/unified_attention_layer.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/nccl_utils.h"

namespace turbomind {
Expand Down Expand Up @@ -46,6 +47,8 @@ class UnifiedDecoder {

const DataType dtype_;

bool need_causal_mask_{false};

using WeightType = LlamaDecoderLayerWeight<T>;

void forwardSelfAttn(T* attn_io,
Expand Down Expand Up @@ -88,6 +91,14 @@ class UnifiedDecoder {
tensor_para_(tensor_para),
dtype_(getTensorType<T>())
{
#ifdef _MSC_VER
// Both unfused MHA and flash attention 1 need causal mask
need_causal_mask_ = true;
#endif
// attention mask is not used for FA-1 (which requires sm80+ and half/bf16 data type)
if (!use_fmha || (getSMVersion() < 80 || sizeof(T) != 2)) {
need_causal_mask_ = true;
}
initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
}

Expand Down
Loading