From 938544b53d3172ba0a89261f7a2902dccbc53800 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 12 Dec 2023 18:44:34 +0800 Subject: [PATCH 1/4] add get_ppl --- lmdeploy/turbomind/turbomind.py | 92 +++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index ad7c0cb518..300358a3a8 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -10,7 +10,7 @@ from contextlib import contextmanager from queue import Queue from threading import Thread -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Union import numpy as np import torch @@ -605,7 +605,11 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): if stream_output and not stop: self.model_insts[0].unregister_callback() - def decode(self, input_ids): + def decode(self, + input_ids, + steps: List[int] = None, + sequence_start: bool = True, + sequence_end: bool = True): """Perform context decode on input tokens. Args: @@ -616,6 +620,9 @@ def decode(self, input_ids): input_ids = [[]] if isinstance(input_ids[0], int): input_ids = [input_ids] + if steps is None: + steps = [0] * len(input_ids) + assert isinstance(steps, List) and len(steps) == len(input_ids) # append an extra token since input_len-1 tokens will be # decoded by context decoder @@ -636,11 +643,17 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.eos_id) + print(f'input_ids.shape: {input_ids.shape}') + steps = torch.IntTensor([step for step in steps]) inputs = dict(input_ids=input_ids, input_lengths=input_lengths, request_output_len=_broadcast_np(0, dtype=np.uint32), - is_return_logits=_broadcast_np(1, np.uint32)) + is_return_logits=_broadcast_np(1, np.uint32), + START=_broadcast_np((1 if sequence_start else 0), + np.int32), + END=_broadcast_np((1 if sequence_end else 0), np.int32), + step=steps) tm_inputs = _np_dict_to_tm_dict(inputs) @@ -651,5 +664,78 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): outputs = _tm_dict_to_torch_dict(tm_outputs) logits = outputs['logits'] + print(f'logits.shape: {logits.shape}') return logits[:, :-1, :] + + def get_ppl(self, input_ids: Union[List[int], List[List[int]]]): + """Get perplexity scores given a list of input tokens. 
+ + Args: + input_ids (Union[List[int], List[List[int]]]): the batch of input token ids + """ # noqa 501 + + if len(input_ids) == 0: + input_ids = [[]] + if isinstance(input_ids[0], int): + input_ids = [input_ids] + + def _get_ppl(logits, input_ids): + shift_logits = logits.contiguous().float() + shift_labels = torch.tensor(input_ids).to(shift_logits.device) + + # loss_fct = torch.nn.CrossEntropyLoss( + # reduction='none', ignore_index=self.tokenizer.pad_token_id) + loss_fct = torch.nn.CrossEntropyLoss(reduction='none') + + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)).view(shift_labels.size()) + ce_loss = loss.sum(-1).cpu().detach().numpy() / len(input_ids) + return ce_loss + + batch_size = len(input_ids) + max_input_len = 16 * 1024 + # max_input_len = 16 + n_max_iter = np.ceil( + max([len(input_id) + for input_id in input_ids]) / max_input_len).astype(int) + + device = 'cuda' + if n_max_iter > 0: + device = 'cpu' + + index_range_starts = [] + index_range_ends = [] + for input_id in input_ids: + index_range_start = np.array( + [i * max_input_len for i in range(n_max_iter)]) + index_range_end = index_range_start + max_input_len + index_range_start[index_range_start >= len(input_id)] = len( + input_id) + index_range_end[index_range_end >= len(input_id)] = len(input_id) + index_range_starts.append(index_range_start) + index_range_ends.append(index_range_end) + + logits = [] + for i in range(n_max_iter): + steps = [start[i] for start in index_range_starts] + _input_ids = [ + input_id[start[i]:end[i]] for input_id, start, end in zip( + input_ids, index_range_starts, index_range_ends) + ] + print(f'the {i}-th: len {len(input_ids)}, {input_ids}') + + _logits = self.decode(_input_ids, + steps, + sequence_start=(i == 0), + sequence_end=(i == n_max_iter - 1)) + _logits = _logits.to(device=device) + logits.append(_logits) + + logits = torch.cat(logits, dim=1) + logits = torch.chunk(logits, chunks=batch_size, dim=0) + ppls = [] + for _logits, _input_ids in zip(logits, input_ids): + ppl = _get_ppl(_logits.squeeze(0), _input_ids) + ppls.append(ppl) + return ppls From 700141d43c034691a2e20399b9a56e9b3c63a042 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 12 Dec 2023 20:39:01 +0800 Subject: [PATCH 2/4] update --- lmdeploy/turbomind/turbomind.py | 60 ++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 300358a3a8..2550955c05 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -643,7 +643,6 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.eos_id) - print(f'input_ids.shape: {input_ids.shape}') steps = torch.IntTensor([step for step in steps]) inputs = dict(input_ids=input_ids, @@ -664,7 +663,6 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): outputs = _tm_dict_to_torch_dict(tm_outputs) logits = outputs['logits'] - print(f'logits.shape: {logits.shape}') return logits[:, :-1, :] @@ -680,29 +678,13 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]): if isinstance(input_ids[0], int): input_ids = [input_ids] - def _get_ppl(logits, input_ids): - shift_logits = logits.contiguous().float() - shift_labels = torch.tensor(input_ids).to(shift_logits.device) - - # loss_fct = torch.nn.CrossEntropyLoss( - # reduction='none', ignore_index=self.tokenizer.pad_token_id) - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - - 
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1)).view(shift_labels.size()) - ce_loss = loss.sum(-1).cpu().detach().numpy() / len(input_ids) - return ce_loss - - batch_size = len(input_ids) max_input_len = 16 * 1024 # max_input_len = 16 n_max_iter = np.ceil( max([len(input_id) for input_id in input_ids]) / max_input_len).astype(int) - device = 'cuda' - if n_max_iter > 0: - device = 'cpu' + device = 'cpu' if n_max_iter > 0 else 'cuda' index_range_starts = [] index_range_ends = [] @@ -723,8 +705,6 @@ def _get_ppl(logits, input_ids): input_id[start[i]:end[i]] for input_id, start, end in zip( input_ids, index_range_starts, index_range_ends) ] - print(f'the {i}-th: len {len(input_ids)}, {input_ids}') - _logits = self.decode(_input_ids, steps, sequence_start=(i == 0), @@ -732,10 +712,36 @@ def _get_ppl(logits, input_ids): _logits = _logits.to(device=device) logits.append(_logits) + # concat logits. Shape is [bsz, seq_len, vocab_size] logits = torch.cat(logits, dim=1) - logits = torch.chunk(logits, chunks=batch_size, dim=0) - ppls = [] - for _logits, _input_ids in zip(logits, input_ids): - ppl = _get_ppl(_logits.squeeze(0), _input_ids) - ppls.append(ppl) - return ppls + + # get target ids + padding_token_id = -100 + target_ids = [(_input_ids + [padding_token_id])[1:] + for _input_ids in input_ids] + target_ids = [ + torch.Tensor(torch.LongTensor(_target_ids)) + for _target_ids in target_ids + ] + target_ids = pad_sequence(target_ids, + batch_first=True, + padding_value=padding_token_id) + target_ids = target_ids.to(logits.device) + target_mask = target_ids != padding_token_id + target_count = torch.sum(target_mask, dim=-1) + + # compute cross entropy loss + bsz, seq_len, vocab_size = logits.shape + flat_logits = logits.contiguous().view(-1, vocab_size) + flat_target_ids = target_ids.contiguous().view(-1) + flat_loss_matrix = torch.nn.functional.cross_entropy( + flat_logits, + flat_target_ids, + reduction='none', + ignore_index=padding_token_id) + + loss_matrix = flat_loss_matrix.view(bsz, seq_len) + loss_sum = torch.sum(loss_matrix * target_mask, dim=1) + loss_avg = loss_sum / target_count + loss_avg = loss_avg.cpu().numpy() + return loss_avg From 5b1cbad48a49935fbf613e927d2260a2843d264e Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 26 Dec 2023 08:54:16 +0000 Subject: [PATCH 3/4] fix get_ppl --- src/turbomind/models/llama/LlamaBatch.cc | 28 +++++++++++++++++++----- src/turbomind/models/llama/LlamaBatch.h | 6 +++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 65e2c53aa5..6444491a89 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1065,9 +1065,10 @@ void LlamaBatch::InitializeSampling(const GenerationState& g) } template -void LlamaBatch::OutputContextLogits(T* context_decoder_output, - const std::vector& indices, - const std::vector& lengths) +void LlamaBatch::OutputContextLogits(T* context_decoder_output, + const std::vector& indices, + const std::vector& lengths, + const std::vector& sequences) { std::vector output_logits; int num_token = 0; @@ -1075,7 +1076,11 @@ void LlamaBatch::OutputContextLogits(T* context_decoder_ bool is_return_logits = false; for (int k = 0; k < indices.size(); ++k) { auto& request = state_->requests[indices[k]]; - output_logits.push_back(request->outputs[rank_].getPtr("logits", nullptr)); + auto logits = request->outputs[rank_].getPtr("logits", nullptr); + if (logits && 
sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) { + logits = nullptr; + } + output_logits.push_back(logits); num_token += lengths[k]; if (output_logits.back()) { is_return_logits = true; @@ -1105,7 +1110,18 @@ void LlamaBatch::OutputContextLogits(T* context_decoder_ for (int k = 0; k < indices.size(); ++k) { if (output_logits[k]) { - Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]); + auto src_ptr = logits; + auto dst_ptr = output_logits[k]; + int num_new_token = 0; + if (sequences[k]->cache_len < sequences[k]->tokens.size()) { + num_new_token = sequences[k]->cache_len + lengths[k] - sequences[k]->tokens.size(); + src_ptr += (lengths[k] - num_new_token) * model_->vocab_size_; + } + else { + num_new_token = lengths[k]; + dst_ptr += (sequences[k]->cache_len - sequences[k]->tokens.size()) * model_->vocab_size_; + } + Copy(src_ptr, model_->vocab_size_ * num_new_token, dst_ptr); } logits += model_->vocab_size_padded_ * lengths[k]; } @@ -1562,7 +1578,7 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) if (iter == 0) { // compute logits of inputs if requested - OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths); + OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences); } } diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 0173ddfceb..9af3b7522f 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -92,8 +92,10 @@ class LlamaBatch { [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false); - void - OutputContextLogits(T* context_decoder_output, const std::vector& indices, const std::vector& lengths); + void OutputContextLogits(T* context_decoder_output, + const std::vector& indices, + const std::vector& lengths, + const std::vector& sequences); explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2* model); From 47289703383b23891968952fc480d216cd72b4a4 Mon Sep 17 00:00:00 2001 From: Xin Chen Date: Wed, 27 Dec 2023 13:33:13 +0800 Subject: [PATCH 4/4] update docstring --- lmdeploy/turbomind/turbomind.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index afb2263ad9..c6ef9d621f 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -806,6 +806,9 @@ def decode(self, Args: input_ids (numpy.ndarray): the batch of input token ids + steps (List[int]): the offset of the k/v cache + sequence_start (bool): indicator for starting a sequence + sequence_end (bool): indicator for ending a sequence """ if len(input_ids) == 0:
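
Note on the windowing in PATCH 1/4: `get_ppl` splits each prompt into windows of at most `max_input_len` (16k) tokens, passes the window start offsets to `decode()` as `steps` (the k/v-cache offsets documented in PATCH 4/4), and sets `sequence_start`/`sequence_end` only for the first and last window. The following is a minimal standalone sketch of that index arithmetic only; `make_windows` is an illustrative helper, not a function from the patches.

```python
import numpy as np


def make_windows(seq_len: int, max_input_len: int):
    """Split a sequence of seq_len tokens into [start, end) windows.

    Mirrors the index_range_starts/index_range_ends bookkeeping in get_ppl:
    each window start is later fed to decode() as its `step` offset.
    """
    n_iter = int(np.ceil(seq_len / max_input_len))
    starts = np.array([i * max_input_len for i in range(n_iter)])
    # clamp both ends to the sequence length, as the patch does
    ends = np.minimum(starts + max_input_len, seq_len)
    starts = np.minimum(starts, seq_len)
    return list(zip(starts.tolist(), ends.tolist()))


# e.g. a 40k-token prompt with the 16k window size used in the patch
print(make_windows(40 * 1024, 16 * 1024))
# -> [(0, 16384), (16384, 32768), (32768, 40960)]
```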
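Note on the loss computation in PATCH 2/4: `get_ppl` builds shifted target ids padded with -100, masks the padding out of a token-level cross-entropy, and returns the per-sequence average loss (`loss_avg`) rather than its exponential; a caller who wants a conventional perplexity value can exponentiate the result. Below is a self-contained sketch of that computation on random logits instead of TurboMind outputs; `average_token_loss`, `dummy_logits`, and the example token ids are illustrative only and not part of the patches.

```python
import math

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence


def average_token_loss(logits: torch.Tensor, input_ids: list) -> np.ndarray:
    """Per-sequence mean cross-entropy, mirroring get_ppl after PATCH 2/4.

    Args:
        logits: float tensor of shape [bsz, seq_len, vocab_size]; position i
            holds the distribution predicting token i+1 of the sequence.
        input_ids: list of per-sequence token-id lists (unpadded).
    """
    padding_token_id = -100
    # shift labels left by one and pad ragged sequences with the ignore index
    target_ids = [(ids + [padding_token_id])[1:] for ids in input_ids]
    target_ids = pad_sequence([torch.LongTensor(t) for t in target_ids],
                              batch_first=True,
                              padding_value=padding_token_id)
    target_ids = target_ids.to(logits.device)
    target_mask = target_ids != padding_token_id
    target_count = torch.sum(target_mask, dim=-1)

    # flatten, compute token-level loss, ignore the padded positions
    bsz, seq_len, vocab_size = logits.shape
    flat_loss = torch.nn.functional.cross_entropy(
        logits.contiguous().view(-1, vocab_size),
        target_ids.contiguous().view(-1),
        reduction='none',
        ignore_index=padding_token_id)
    loss_matrix = flat_loss.view(bsz, seq_len)
    loss_avg = torch.sum(loss_matrix * target_mask, dim=1) / target_count
    return loss_avg.cpu().numpy()


if __name__ == '__main__':
    vocab_size = 32
    input_ids = [[3, 7, 11, 2], [5, 9, 1]]
    seq_len = max(len(ids) for ids in input_ids)
    dummy_logits = torch.randn(len(input_ids), seq_len, vocab_size)
    losses = average_token_loss(dummy_logits, input_ids)
    # get_ppl returns the average loss; perplexity is its exponential
    print([math.exp(loss) for loss in losses])
```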