diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index b1ad357462..c6ef9d621f 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -10,7 +10,7 @@
 from contextlib import contextmanager
 from queue import Queue
 from threading import Thread
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Union

 import numpy as np
 import torch
@@ -797,17 +797,27 @@ def stream_infer(self,
         if stream_output and not stop:
             self.model_insts[0].unregister_callback()

-    def decode(self, input_ids):
+    def decode(self,
+               input_ids,
+               steps: List[int] = None,
+               sequence_start: bool = True,
+               sequence_end: bool = True):
         """Perform context decode on input tokens.

         Args:
             input_ids (numpy.ndarray): the batch of input token ids
+            steps (List[int]): the offset of the k/v cache for each sequence
+            sequence_start (bool): indicator for starting a sequence
+            sequence_end (bool): indicator for ending a sequence
         """

         if len(input_ids) == 0:
             input_ids = [[]]
         if isinstance(input_ids[0], int):
             input_ids = [input_ids]
+        if steps is None:
+            steps = [0] * len(input_ids)
+        assert isinstance(steps, List) and len(steps) == len(input_ids)

         # append an extra token since input_len-1 tokens will be
         # decoded by context decoder
@@ -828,11 +838,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
         input_ids = pad_sequence(input_ids,
                                  batch_first=True,
                                  padding_value=self.eos_id)
+        steps = torch.IntTensor([step for step in steps])

         inputs = dict(input_ids=input_ids,
                       input_lengths=input_lengths,
                       request_output_len=_broadcast_np(0, dtype=np.uint32),
-                      is_return_logits=_broadcast_np(1, np.uint32))
+                      is_return_logits=_broadcast_np(1, np.uint32),
+                      START=_broadcast_np((1 if sequence_start else 0),
+                                          np.int32),
+                      END=_broadcast_np((1 if sequence_end else 0), np.int32),
+                      step=steps)

         tm_inputs = _np_dict_to_tm_dict(inputs)
@@ -845,3 +860,83 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
         logits = outputs['logits']

         return logits[:, :-1, :]
+
+    def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
+        """Get perplexity scores given a list of input tokens.
+
+        Args:
+            input_ids (Union[List[int], List[List[int]]]): the batch of input token ids
+        """  # noqa 501
+
+        if len(input_ids) == 0:
+            input_ids = [[]]
+        if isinstance(input_ids[0], int):
+            input_ids = [input_ids]
+
+        max_input_len = 16 * 1024
+        n_max_iter = np.ceil(
+            max([len(input_id)
+                 for input_id in input_ids]) / max_input_len).astype(int)
+
+        device = 'cpu' if n_max_iter > 0 else 'cuda'
+
+        index_range_starts = []
+        index_range_ends = []
+        for input_id in input_ids:
+            index_range_start = np.array(
+                [i * max_input_len for i in range(n_max_iter)])
+            index_range_end = index_range_start + max_input_len
+            index_range_start[index_range_start >= len(input_id)] = len(
+                input_id)
+            index_range_end[index_range_end >= len(input_id)] = len(input_id)
+            index_range_starts.append(index_range_start)
+            index_range_ends.append(index_range_end)
+
+        logits = []
+        for i in range(n_max_iter):
+            steps = [start[i] for start in index_range_starts]
+            _input_ids = [
+                input_id[start[i]:end[i]] for input_id, start, end in zip(
+                    input_ids, index_range_starts, index_range_ends)
+            ]
+            _logits = self.decode(_input_ids,
+                                  steps,
+                                  sequence_start=(i == 0),
+                                  sequence_end=(i == n_max_iter - 1))
+            _logits = _logits.to(device=device)
+            logits.append(_logits)
+
+        # concat logits. Shape is [bsz, seq_len, vocab_size]
+        logits = torch.cat(logits, dim=1)
+
+        # get target ids
+        padding_token_id = -100
+        target_ids = [(_input_ids + [padding_token_id])[1:]
+                      for _input_ids in input_ids]
+        target_ids = [
+            torch.LongTensor(_target_ids)
+            for _target_ids in target_ids
+        ]
+        target_ids = pad_sequence(target_ids,
+                                  batch_first=True,
+                                  padding_value=padding_token_id)
+        target_ids = target_ids.to(logits.device)
+        target_mask = target_ids != padding_token_id
+        target_count = torch.sum(target_mask, dim=-1)
+
+        # compute cross entropy loss
+        bsz, seq_len, vocab_size = logits.shape
+        flat_logits = logits.contiguous().view(-1, vocab_size)
+        flat_target_ids = target_ids.contiguous().view(-1)
+        flat_loss_matrix = torch.nn.functional.cross_entropy(
+            flat_logits,
+            flat_target_ids,
+            reduction='none',
+            ignore_index=padding_token_id)
+
+        loss_matrix = flat_loss_matrix.view(bsz, seq_len)
+        loss_sum = torch.sum(loss_matrix * target_mask, dim=1)
+        loss_avg = loss_sum / target_count
+        loss_avg = loss_avg.cpu().numpy()
+        return loss_avg
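For reviewers, a minimal usage sketch of the new Python API follows. It is not part of the diff: the local model path, the `Tokenizer` usage, and the `from_pretrained`/`create_instance` flow are assumptions based on the existing lmdeploy examples. `get_ppl` returns the per-sequence average cross-entropy loss, so perplexity is its exponential.

```python
# Sketch only: scoring two prompts with the new get_ppl API.
# Model path and tokenizer handling are assumptions, not part of this PR.
import numpy as np

from lmdeploy.tokenizer import Tokenizer
from lmdeploy.turbomind import TurboMind

model_path = './internlm-chat-7b'  # hypothetical local model folder
tokenizer = Tokenizer(model_path)
tm_model = TurboMind.from_pretrained(model_path)
generator = tm_model.create_instance()

prompts = [
    'The quick brown fox jumps over the lazy dog.',
    'TurboMind splits long inputs into chunks before scoring them.',
]
input_ids = [tokenizer.encode(p) for p in prompts]

# get_ppl chunks each sequence (at most 16k tokens per decode call),
# passes the running k/v-cache offsets as `steps`, and returns one
# averaged cross-entropy loss per sequence.
loss = generator.get_ppl(input_ids)
ppl = np.exp(loss)
print(ppl)  # shape: [len(prompts)]
```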
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 65e2c53aa5..6444491a89 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -1065,9 +1065,10 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
 }

 template<typename T>
-void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_output,
-                                        const std::vector<int>& indices,
-                                        const std::vector<int>& lengths)
+void LlamaBatch<T>::OutputContextLogits(T*                                  context_decoder_output,
+                                        const std::vector<int>&             indices,
+                                        const std::vector<int>&             lengths,
+                                        const std::vector<const Sequence*>& sequences)
 {
     std::vector<float*> output_logits;
     int                 num_token = 0;
@@ -1075,7 +1076,11 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
         bool is_return_logits = false;
         for (int k = 0; k < indices.size(); ++k) {
             auto& request = state_->requests[indices[k]];
-            output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
+            auto  logits  = request->outputs[rank_].getPtr<float>("logits", nullptr);
+            if (logits && sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) {
+                logits = nullptr;
+            }
+            output_logits.push_back(logits);
             num_token += lengths[k];
             if (output_logits.back()) {
                 is_return_logits = true;
             }
@@ -1105,7 +1110,18 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
     for (int k = 0; k < indices.size(); ++k) {
         if (output_logits[k]) {
-            Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
+            auto src_ptr       = logits;
+            auto dst_ptr       = output_logits[k];
+            int  num_new_token = 0;
+            if (sequences[k]->cache_len < sequences[k]->tokens.size()) {
+                num_new_token = sequences[k]->cache_len + lengths[k] - sequences[k]->tokens.size();
+                src_ptr += (lengths[k] - num_new_token) * model_->vocab_size_;
+            }
+            else {
+                num_new_token = lengths[k];
+                dst_ptr += (sequences[k]->cache_len - sequences[k]->tokens.size()) * model_->vocab_size_;
+            }
+            Copy(src_ptr, model_->vocab_size_ * num_new_token, dst_ptr);
         }
         logits += model_->vocab_size_padded_ * lengths[k];
     }
@@ -1562,7 +1578,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
         if (iter == 0) {
             // compute logits of inputs if requested
-            OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
+            OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences);
         }
     }
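The branch added to `OutputContextLogits` decides which slice of the freshly computed logits actually belongs to the caller. A rough Python rendering of the same arithmetic, with hypothetical local names: `cache_len` corresponds to `sequences[k]->cache_len`, `history_len` to `sequences[k]->tokens.size()`, and `chunk_len` to `lengths[k]`.

```python
def logits_copy_window(cache_len: int, history_len: int, chunk_len: int):
    """Return (src_offset, dst_offset, num_new_token), all counted in tokens.

    Illustrative sketch of the branch added to OutputContextLogits: logits
    are copied only for tokens that extend beyond the recorded history.
    """
    if cache_len + chunk_len <= history_len:
        # Chunk lies entirely inside known history: nothing is returned
        # (the C++ code sets the output pointer to nullptr in this case).
        return None
    if cache_len < history_len:
        # Chunk partially overlaps history: skip the overlapping prefix on
        # the source side and copy only the trailing new tokens.
        num_new_token = cache_len + chunk_len - history_len
        return chunk_len - num_new_token, 0, num_new_token
    # Chunk is entirely new: copy all of it, shifted on the destination side
    # by how far the cache already extends past the recorded history.
    return 0, cache_len - history_len, chunk_len


# e.g. a 1000-token chunk whose first 900 tokens were already recorded:
print(logits_copy_window(cache_len=0, history_len=900, chunk_len=1000))
# -> (900, 0, 100): only the last 100 rows of logits are copied
```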
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 0173ddfceb..9af3b7522f 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -92,8 +92,10 @@ class LlamaBatch {
     [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);

-    void
-    OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
+    void OutputContextLogits(T*                                  context_decoder_output,
+                             const std::vector<int>&             indices,
+                             const std::vector<int>&             lengths,
+                             const std::vector<const Sequence*>& sequences);

     explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model);
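Putting the two sides together, the lower-level `decode` API can now be driven chunk by chunk, which is what `get_ppl` does internally. The sketch below reuses the `generator` instance from the earlier example; the 8-token chunk size and the `token_ids` list are arbitrary and only for illustration.

```python
# Sketch only: feeding one long sequence to decode() in chunks.
# `generator` is a TurboMindInstance as created in the earlier example;
# `token_ids` is any List[int] longer than the chosen chunk size.
chunk_size = 8  # arbitrary for illustration; get_ppl uses 16 * 1024
chunks = [token_ids[i:i + chunk_size]
          for i in range(0, len(token_ids), chunk_size)]

logits = []
for i, chunk in enumerate(chunks):
    # steps carries the k/v-cache offset of this chunk within the sequence;
    # sequence_start/sequence_end map to the START/END flags on the C++ side.
    out = generator.decode([chunk],
                           steps=[i * chunk_size],
                           sequence_start=(i == 0),
                           sequence_end=(i == len(chunks) - 1))
    logits.append(out)
```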