support min_length for turbomind backend #961

Merged (6 commits) · Jan 26, 2024
lmdeploy/messages.py (4 additions, 1 deletion)

@@ -30,6 +30,8 @@ class GenerationConfig:
         random_seed (int): Seed used when sampling a token
         stop_words (List[str]): Words that stop generating further tokens
         bad_words (List[str]): Words that the engine will never generate
+        min_new_tokens (int): The minimum number of tokens to generate,
+            ignoring the number of tokens in the prompt.
     """

     n: int = 1
@@ -42,6 +44,7 @@ class GenerationConfig:
     random_seed: int = None
     stop_words: List[str] = None
     bad_words: List[str] = None
+    min_new_tokens: int = None


 @dataclass
@@ -65,7 +68,7 @@ def From(gen_config: GenerationConfig, tokenizer: Tokenizer):
         >>> tokenizer = Tokenizer('internlm/internlm-chat-7b')
         >>> gen_config = GenerationConfig(stop_words=['<eoa>'])
         >>> gen_config = EngineGenerationConfig.From(gen_config, tokenizer)
-        """ # noqa E501
+        """ # noqa E501
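For context, a minimal usage sketch of the new option, modeled on the docstring example in this file (the import paths and model name are assumptions for illustration, not part of the diff):

    from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
    from lmdeploy.tokenizer import Tokenizer

    # Ask for at least 32 newly generated tokens; the prompt length is not counted.
    gen_config = GenerationConfig(min_new_tokens=32, stop_words=['<eoa>'])

    # Convert to the engine-level config, as in the docstring example above.
    tokenizer = Tokenizer('internlm/internlm-chat-7b')
    engine_gen_config = EngineGenerationConfig.From(gen_config, tokenizer)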
lmdeploy/turbomind/turbomind.py (4 additions, 0 deletions)

@@ -675,6 +675,10 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             inputs['input_embeddings'] = input_embeddings
             inputs['input_embedding_ranges'] = input_embedding_ranges

+        if gen_config.min_new_tokens is not None:
+            inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens,
+                                                 np.int32)
+
         bad_words = []
         if gen_config.bad_words is not None:
             bad_words.extend(gen_config.bad_words)
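The body of _broadcast_np is not part of this hunk; the sketch below assumes it simply tiles a scalar across the batch, to show how a single min_new_tokens value becomes the per-request 'min_length' tensor consumed by the sampling layer:

    import numpy as np

    batch_size = 4

    def _broadcast_np(data, dtype, shape=(batch_size, )):
        # Assumed behaviour: replicate one scalar per request in the batch.
        return np.full(shape, data, dtype=dtype)

    inputs = {}
    min_new_tokens = 32
    inputs['min_length'] = _broadcast_np(min_new_tokens, np.int32)
    print(inputs['min_length'])  # [32 32 32 32]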
src/turbomind/kernels/sampling_penalty_kernels.cu (2 additions, 3 deletions)

@@ -497,9 +497,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits,
                                            const int vocab_size_padded)
 {
     int bid = threadIdx.x + blockIdx.x * blockDim.x;  // batch index
-    // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
-    // which is equal to the length of k/v caches.
-    if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
+    // In the decoder, sequence_lengths is the length of the sequence whose k/v cache has already been computed.
+    if (sequence_lengths[bid] + 1 < min_lengths[bid]) {
         T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
         logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
     }
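For readers who prefer host-side pseudocode, here is a NumPy sketch of what the revised kernel condition does per batch entry (names mirror the kernel arguments; this is an illustration, not part of the change):

    import numpy as np

    def batch_apply_min_length_penalty(logits, min_lengths, end_ids, sequence_lengths, is_fp16=False):
        # Mask the end-of-sequence logit while a sequence is still below its minimum length.
        mask_val = -65504.0 if is_fp16 else -np.finfo(np.float32).max
        for bid in range(logits.shape[0]):
            # sequence_lengths[bid] counts tokens whose k/v cache is already computed,
            # so the token being sampled now sits at position sequence_lengths[bid] + 1.
            if sequence_lengths[bid] + 1 < min_lengths[bid]:
                logits[bid, end_ids[bid]] = mask_val  # EOS cannot be sampled this step
        return logits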
src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc (22 additions, 11 deletions)

@@ -45,6 +45,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
     repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float));
     min_lengths_ = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int));
     skip_decode_ = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool));
+    context_length_ = (int*)std::realloc((void*)context_length_, batch_size * sizeof(int));

     is_allocate_buffer_ = true;
 }
@@ -63,6 +64,7 @@ void BaseSamplingLayer<T>::freeBuffer()
         std::free(repetition_penalty_);
         std::free(min_lengths_);
         std::free(skip_decode_);
+        std::free(context_length_);
         is_allocate_buffer_ = false;
     }
 }
@@ -161,16 +163,23 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
         repetition_penalty_type_ = RepetitionPenaltyType::None;
     }

-    const int default_min_length = 0;
-    Tensor min_lengths = runtime_args->at("min_length", Tensor(MEMORY_CPU, TYPE_INT32, {1}, &default_min_length));
-    if (min_lengths.size() == 1) {
-        int minlen = min_lengths.getVal<int>();
-        deviceFill(min_lengths_buf_, batch_size, minlen, stream_);
-        std::fill_n(min_lengths_, batch_size, minlen);
+    // min_length
+    if (runtime_args->isExist("min_length")) {
+        Tensor min_lengths = runtime_args->at("min_length");
+        Tensor context_lengths = runtime_args->at("context_length");
+        Tensor prompt_lengths = runtime_args->at("prompt_length");
+        auto p1 = min_lengths.getPtr<int>();
+        auto p2 = prompt_lengths.getPtr<int>();
+        for (int i = 0; i < batch_size; i++) {
+            min_lengths_[i] = p1[i] + p2[i];
+        }
+        cudaAutoCpy(min_lengths_buf_, min_lengths_, batch_size, stream_);
+        std::copy_n(context_lengths.getPtr<int>(), batch_size, context_length_);
     }
     else {
-        cudaAutoCpy(min_lengths_buf_, min_lengths.getPtr<int>(), batch_size, stream_);
-        std::copy_n(min_lengths.getPtr<int>(), batch_size, min_lengths_);
+        std::fill_n(min_lengths_, batch_size, 0);
+        deviceFill(min_lengths_buf_, batch_size, 0, stream_);
+        std::fill_n(context_length_, batch_size, 0);
     }
 }

@@ -300,10 +309,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
         }
     }

-    const int num_generated_tokens = step - max_input_length;
-    const int* min_lengths = min_lengths_ + ite * local_batch_size;
+    const int num_generated_tokens = step - max_input_length;
+    const int* min_lengths = min_lengths_ + ite * local_batch_size;
+    std::vector<int> index(local_batch_size);
+    std::iota(index.begin(), index.end(), 0);
     const bool invoke_min_length_penalty = std::any_of(
-        min_lengths, min_lengths + local_batch_size, [&](int min_length) { return min_length > num_generated_tokens; });
+        index.begin(), index.end(), [&](int i) { return min_lengths[i] > context_length_[i] + num_generated_tokens; });
     if (invoke_min_length_penalty) {
         FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penalty");
         invokeMinLengthPenalty(logits,
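The two changed code paths above reduce to simple per-request bookkeeping; a Python sketch under the same naming (illustrative only):

    def setup_min_lengths(min_length, prompt_length):
        # setup(): store each request's minimum as a total sequence length,
        # i.e. the requested number of new tokens plus the prompt length.
        return [m + p for m, p in zip(min_length, prompt_length)]

    def need_min_length_penalty(min_lengths, context_lengths, num_generated_tokens):
        # forward(): invoke the penalty kernel only while at least one request
        # has not yet reached its minimum total length.
        return any(m > c + num_generated_tokens
                   for m, c in zip(min_lengths, context_lengths))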
src/turbomind/layers/sampling_layers/BaseSamplingLayer.h (1 addition, 0 deletions)

@@ -47,6 +47,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     int* min_lengths_ = nullptr;
     bool* skip_decode_ = nullptr;
     bool skip_any_ = false;
+    int* context_length_ = nullptr;

     RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None;

src/turbomind/models/llama/LlamaBatch.cc (12 additions, 0 deletions)

@@ -328,6 +328,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         }

         // total context length (history + input)
+        state.h_prompt_length[idx] = output_ids - output_ids_base;
         state.h_context_length[idx] = output_ids - output_ids_base;
         state.h_finished[idx] = false;

@@ -698,6 +699,7 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
     }

     for (const auto& [s, d, si, di] : desc) {
+        d->h_prompt_length[di] = s->h_prompt_length[si];
         d->h_context_length[di] = s->h_context_length[si];
         d->h_finished[di] = s->h_finished[si];
         d->h_rope_theta[di] = s->h_rope_theta[si];
@@ -772,6 +774,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     h_bad_words_ =
         (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);

+    h_min_length_ = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
     h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
@@ -794,6 +797,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     sampling_params_ = {
         {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
         {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
+        {"min_length", (std::byte*)h_min_length_, nullptr},
         {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
         {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
         {"temperature", (std::byte*)h_temperature_, nullptr},
@@ -826,6 +830,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
         (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);

     for (auto& s : states_) {
+        s.h_prompt_length =
+            (int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true);
         s.h_context_length =
             (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
         s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
@@ -1057,6 +1063,12 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         }
     }

+    // MinLengthPenalty
+    if (inputs.isExist("min_length")) {
+        inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
+        inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
+    }
+
     // init for eos
     std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
     Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
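A compressed view of the batching-side plumbing added above, as Python pseudocode of InitializeSampling (the dict-style tensor map is an assumption for illustration; field names follow the C++ code):

    def initialize_sampling(inputs, state, batch_size):
        # The prompt/context lengths are only handed to the sampling layer
        # when the request actually carries a min_length.
        if 'min_length' in inputs:
            inputs['prompt_length'] = state.h_prompt_length[:batch_size]
            inputs['context_length'] = state.h_context_length[:batch_size]
        return inputs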
src/turbomind/models/llama/LlamaBatch.h (2 additions, 0 deletions)

@@ -20,6 +20,7 @@
 namespace turbomind {

 struct BatchState {
+    int* h_prompt_length;  // history + input, ignore generated
     int* h_context_length;
     bool* h_finished;

@@ -248,6 +249,7 @@ class LlamaBatch {
     uintptr_t* h_k_block_ptrs_{};
     uintptr_t* h_v_block_ptrs_{};

+    int* h_min_length_{};
     int* h_runtime_top_k_{};
     float* h_runtime_top_p_{};
     float* h_temperature_{};