diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index a11f18c2f3..5a898cdf4e 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -1030,34 +1030,3 @@ async def async_end(self, session_id: int):
     def end(self, session_id: int):
         """Add new session."""
         return self.engine_instance.end(session_id)
-
-    def decode(self,
-               input_ids,
-               input_embeddings: List[InputEmbeddingType] = None,
-               input_embedding_ranges: List[InputEmbeddingRangeType] = None,
-               steps: List[int] = None,
-               sequence_start: bool = True,
-               sequence_end: bool = True,
-               adapter_names: List[str] = None):
-        """Perform context decode on input tokens.
-
-        Args:
-            input_ids (List[List[int]] | List[np.ndaray]): the batch of input
-                token ids
-            steps (List[int]): the offset of the k/v cache
-            input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
-                embeddings features
-            input_embedding_ranges: (List[List[Tuple[int, int]]]):
-                the begin/end offsets of input_embeddings to input_ids
-            sequence_start (bool): indicator for starting a sequence
-            sequence_end (bool): indicator for ending a sequence
-            adapter_names (List[str]): The name of the adapters.
-        """
-        return self.engine_instance.decode(
-            input_ids,
-            input_embeddings=input_embeddings,
-            input_embedding_ranges=input_embedding_ranges,
-            steps=steps,
-            sequence_start=sequence_start,
-            sequence_end=sequence_end,
-            adapter_names=adapter_names)
diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py
index 8945a71a7f..4791d3c724 100644
--- a/lmdeploy/serve/utils.py
+++ b/lmdeploy/serve/utils.py
@@ -81,7 +81,12 @@ def get_logits(
         for input_id in input_ids:
             assert len(input_id) > 0
 
-        max_input_len = self.backend_config.max_prefill_token_num
+        bs = len(input_ids)
+        # TODO: a better way to determine `max_input_len`, at most allocate
+        # 2G mem for logits with shape [bs, max_input_len, vocab_size]
+        vocab_size = self.hf_tm_cfg.vocab_size
+        max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
+
         n_max_iter = np.ceil(
             max([len(input_id)
                  for input_id in input_ids]) / max_input_len).astype(int)
@@ -173,79 +178,171 @@ def _split_embeddings(input_ids, niter, iter_len, embeddings,
             logits = torch.cat(logits, dim=1)
         return logits
 
-    def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
-        """Get perplexity scores given a list of input tokens.
+    def get_ppl(self, input_ids: Union[List[int],
+                                       List[List[int]]]) -> List[float]:
+        """Get perplexity scores given a batch of input token sequences,
+        which may differ in length.
 
         Args:
             input_ids (Union[List[int], List[List[int]]]): the batch of input
                 token ids
+
+        Returns:
+            List[float]: A list of perplexity scores.
""" - assert len(input_ids) > 0 + assert isinstance(input_ids, List) if isinstance(input_ids[0], int): input_ids = [input_ids] - for input_id in input_ids: - assert len(input_id) > 1 - - max_input_len = self.backend_config.max_prefill_token_num - n_max_iter = np.ceil( - max([len(input_id) - for input_id in input_ids]) / max_input_len).astype(int) - - index_range_starts = [] - index_range_ends = [] - for input_id in input_ids: - index_range_start = np.array( - [i * max_input_len for i in range(n_max_iter)]) - index_range_end = index_range_start + max_input_len - index_range_start[index_range_start >= len(input_id)] = len( - input_id) - index_range_end[index_range_end >= len(input_id)] = len(input_id) - index_range_starts.append(index_range_start) - index_range_ends.append(index_range_end) generator = self.engine.create_instance() - all_loss_matrix = [] - all_target_mask = [] - for i in range(n_max_iter): - steps = [start[i] for start in index_range_starts] - _input_ids = [ - input_id[start[i]:end[i]] for input_id, start, end in zip( - input_ids, index_range_starts, index_range_ends) - ] - _logits = generator.decode(_input_ids, - steps=steps, - sequence_start=(i == 0), - sequence_end=(i == n_max_iter - 1)) - _logits = _logits.float().cpu() - padding_token_id = -100 - target_ids = [(x + [padding_token_id])[1:] for x in _input_ids] + + # TODO: a better way to determine `max_input_len`, at most allocate + # 2G mem for logits with shape [bs, max_input_len, vocab_size] + vocab_size = self.hf_tm_cfg.vocab_size + max_input_len = 2 * 1024**3 // (vocab_size * 4) + sizes = [len(_) for _ in input_ids] + losses = [] + target_counts = [] + sorted_index_values = sorted(list(enumerate(sizes)), + key=lambda x: x[1], + reverse=True) + sizes = [value for index, value in sorted_index_values] + indices = [index for index, value in sorted_index_values] + logger.info(f'sorted sizes: {sizes}') + logger.info(f'sorted indices: {indices}') + for (start, end) in self._batch_iterator(sizes, max_input_len): + logger.info(f'start: {start}, end: {end}') + _input_ids = [input_ids[indices[i]] for i in range(start, end)] + if start == end: + loss, target_count = self._get_long_text_ppl( + generator=generator, + input_ids=_input_ids, + max_input_len=max_input_len) + losses.append(loss) + target_counts.append(target_count) + else: + loss, target_count = self._get_ppl( + generator=generator, + input_ids=_input_ids, + max_input_len=max_input_len, + ) + losses.append(loss) + target_counts.append(target_count) + loss = torch.concatenate(losses) + target_count = torch.concatenate(target_counts) + loss_avg = loss / target_count + loss_avg = loss_avg.numpy().tolist() + result = list(range(len(loss_avg))) + for index, sorted_index in enumerate(indices): + result[sorted_index] = loss_avg[index] + return result + + def _batch_iterator(self, sizes, max_value): + """Return an iterator that calculates intervals (start, end) of a + descend-order list, in which the sum of values in the range is the + maximum number not less than max_value. 
By "the sum of values", + + here it means $$len(sizes[start:end]) * sizes[start]$$ + """ + i = 0 + while i < len(sizes): + current_sum = 0 + start_index = i + + while i < len( + sizes) and current_sum + sizes[start_index] <= max_value: + current_sum += sizes[start_index] + i += 1 + + yield (start_index, i) + if i > start_index: + continue + else: + i += 1 + + def _get_long_text_ppl(self, generator, input_ids, max_input_len): + assert isinstance(input_ids, List) and len(input_ids) == 1 + seq_len = len(input_ids[0]) + assert seq_len > max_input_len + logger.info(f'get long text ppl: seq_len {seq_len}') + + losses = [] + target_counts = [] + for i in range(0, seq_len, max_input_len): + token_ids = input_ids[:, i:i + max_input_len] + step = [i] + # shift token_ids by 1 to the left + target_ids = input_ids[:, i + 1:i + 1 + max_input_len] + + loss, target_count = self._get_ppl( + generator=generator, + input_ids=token_ids, + max_input_len=max_input_len, + target_ids=target_ids, + steps=step, + sequence_start=(i == 0), + sequence_end=(i + max_input_len >= seq_len)) + losses.append(loss) + target_counts.append(target_count) + loss_sum = torch.concatenate(losses).sum().unsqueeze(0) + target_count = torch.concatenate(target_counts).sum().unsqueeze(0) + return loss_sum, target_count + + def _get_ppl(self, + generator, + input_ids, + max_input_len, + target_ids=None, + steps=None, + sequence_start: bool = True, + sequence_end: bool = True): + assert isinstance(input_ids, List) + assert all(isinstance(_, List) for _ in input_ids) + if target_ids: + assert all(isinstance(_, List) for _ in target_ids) + + lens = [len(_) for _ in input_ids] + total_len = sum(lens) + assert sum(lens) <= max_input_len + + logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' + f'total_len: {total_len}') + torch.cuda.empty_cache() + logits = generator.decode(input_ids=input_ids, + steps=steps, + sequence_start=sequence_start, + sequence_end=sequence_end) + bsz, seq_len, vocab_size = logits.shape + logits = logits.float() + padding_token_id = -100 + if target_ids is None: + # shift token_ids by 1 to the left + target_ids = [x[1:] + [padding_token_id] for x in input_ids] + else: target_ids = [ - torch.Tensor(torch.LongTensor(_target_ids)) - for _target_ids in target_ids + target_ids[i] + [padding_token_id] + if len(target_ids[i]) < len(input_ids[i]) else target_ids[i] + for i in range(bsz) ] - target_ids = pad_sequence(target_ids, - batch_first=True, - padding_value=padding_token_id) - target_ids = target_ids.to(_logits.device) - target_mask = target_ids != padding_token_id - target_count = torch.sum(target_mask, dim=-1) - # compute cross entropy loss - bsz, seq_len, vocab_size = _logits.shape - flat_logits = _logits.contiguous().view(-1, vocab_size) - flat_target_ids = target_ids.contiguous().view(-1) - flat_loss_matrix = torch.nn.functional.cross_entropy( - flat_logits, - flat_target_ids, - reduction='none', - ignore_index=padding_token_id) - - all_loss_matrix.append(flat_loss_matrix.view(bsz, seq_len)) - all_target_mask.append(target_mask) - - all_loss_matrix = torch.cat(all_loss_matrix, dim=1) - all_target_mask = torch.cat(all_target_mask, dim=1) - target_count = torch.sum(all_target_mask, dim=-1) - loss_sum = torch.sum(all_loss_matrix * all_target_mask, dim=1) - loss_avg = loss_sum / target_count - loss_avg = loss_avg.cpu().numpy() - return loss_avg + target_ids = [ + torch.Tensor(torch.LongTensor(_target_ids)) + for _target_ids in target_ids + ] + target_ids = pad_sequence(target_ids, + batch_first=True, + 
+                                  padding_value=padding_token_id)
+        target_ids = target_ids.to(logits.device)
+        target_mask = target_ids != padding_token_id
+
+        # compute cross entropy loss
+        flat_logits = logits.contiguous().view(-1, vocab_size)
+        flat_target_ids = target_ids.contiguous().view(-1)
+        flat_loss_matrix = torch.nn.functional.cross_entropy(
+            flat_logits,
+            flat_target_ids,
+            reduction='none',
+            ignore_index=padding_token_id)
+        flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
+        loss = flat_loss_matrix.sum(dim=-1).cpu()
+        target_count = target_mask.sum(dim=-1).cpu()
+        return loss, target_count
diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py
index bec6120b7b..4ee464e46d 100644
--- a/lmdeploy/turbomind/deploy/config.py
+++ b/lmdeploy/turbomind/deploy/config.py
@@ -135,5 +135,9 @@ def weight_type(self):
     def group_size(self):
         return self.model_config.group_size
 
+    @property
+    def vocab_size(self):
+        return self.model_config.vocab_size
+
     def __str__(self):
         return json.dumps(self.to_dict(), indent=2)
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 00b419ded1..05bc3e400e 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -8,7 +8,7 @@
 from dataclasses import asdict
 from itertools import repeat
 from queue import LifoQueue, Queue
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List
 
 import numpy as np
 import torch
@@ -314,7 +314,7 @@ def create_instance(self, cuda_stream_id=0):
         Returns:
             TurboMindInstance: an instance of turbomind
         """
-        return TurboMindInstance(self, cuda_stream_id)
+        return TurboMindInstance(self, self.config, cuda_stream_id)
 
 
 class TurboMindInstance:
@@ -325,7 +325,10 @@ class TurboMindInstance:
         cuda_stream_id(int): identity of a cuda stream
     """
 
-    def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
+    def __init__(self,
+                 tm_model: TurboMind,
+                 config: TurbomindModelConfig,
+                 cuda_stream_id: int = 0):
         self.tm_model = tm_model
         self.cuda_stream_id = cuda_stream_id
 
@@ -343,6 +346,7 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
         self.que = Queue()
         self.executor: ThreadPoolExecutor = None
         self.future = None
+        self.config = config
 
     def _create_model_instance(self, device_id):
         rank = self.node_id * self.gpu_count + device_id
@@ -922,85 +926,3 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
 
         logits = outputs['logits']
 
         return logits[:, :-1, :]
-
-    def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
-        """Get perplexity scores given a list of input tokens.
-
-        Args:
-            input_ids (Union[List[int], List[List[int]]]): the batch of input token ids
-        """  # noqa 501
-
-        if len(input_ids) == 0:
-            input_ids = [[]]
-        if isinstance(input_ids[0], int):
-            input_ids = [input_ids]
-
-        max_input_len = 16 * 1024
-        # max_input_len = 16
-        n_max_iter = np.ceil(
-            max([len(input_id)
-                 for input_id in input_ids]) / max_input_len).astype(int)
-
-        device = 'cpu' if n_max_iter > 1 else 'cuda'
-
-        index_range_starts = []
-        index_range_ends = []
-        for input_id in input_ids:
-            index_range_start = np.array(
-                [i * max_input_len for i in range(n_max_iter)])
-            index_range_end = index_range_start + max_input_len
-            index_range_start[index_range_start >= len(input_id)] = len(
-                input_id)
-            index_range_end[index_range_end >= len(input_id)] = len(input_id)
-            index_range_starts.append(index_range_start)
-            index_range_ends.append(index_range_end)
-
-        logits = []
-        for i in range(n_max_iter):
-            steps = [start[i] for start in index_range_starts]
-            _input_ids = [
-                input_id[start[i]:end[i]] for input_id, start, end in zip(
-                    input_ids, index_range_starts, index_range_ends)
-            ]
-            _logits = self.decode(_input_ids,
-                                  steps,
-                                  sequence_start=(i == 0),
-                                  sequence_end=(i == n_max_iter - 1))
-            if _logits is None:
-                return None
-            _logits = _logits.to(device=device)
-            logits.append(_logits)
-
-        # concat logits. Shape is [bsz, seq_len, vocab_size]
-        logits = torch.cat(logits, dim=1)
-
-        # get target ids
-        padding_token_id = -100
-        target_ids = [(_input_ids + [padding_token_id])[1:]
-                      for _input_ids in input_ids]
-        target_ids = [
-            torch.Tensor(torch.LongTensor(_target_ids))
-            for _target_ids in target_ids
-        ]
-        target_ids = pad_sequence(target_ids,
-                                  batch_first=True,
-                                  padding_value=padding_token_id)
-        target_ids = target_ids.to(logits.device)
-        target_mask = target_ids != padding_token_id
-        target_count = torch.sum(target_mask, dim=-1)
-
-        # compute cross entropy loss
-        bsz, seq_len, vocab_size = logits.shape
-        flat_logits = logits.contiguous().view(-1, vocab_size)
-        flat_target_ids = target_ids.contiguous().view(-1)
-        flat_loss_matrix = torch.nn.functional.cross_entropy(
-            flat_logits,
-            flat_target_ids,
-            reduction='none',
-            ignore_index=padding_token_id)
-
-        loss_matrix = flat_loss_matrix.view(bsz, seq_len)
-        loss_sum = torch.sum(loss_matrix * target_mask, dim=1)
-        loss_avg = loss_sum / target_count
-        loss_avg = loss_avg.cpu().numpy()
-        return loss_avg
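
Note: the following is an editorial sketch, not part of the patch. It shows in isolation how the 2 GiB logits budget from the TODO comment translates into `max_input_len`, and how `_batch_iterator` partitions a descending size list under that budget. `vocab_size` and the toy `sizes` are illustrative assumptions.

    from typing import Iterator, List, Tuple


    def batch_iterator(sizes: List[int],
                       max_value: int) -> Iterator[Tuple[int, int]]:
        """Yield (start, end) such that the padded cost
        len(sizes[start:end]) * sizes[start] stays within max_value;
        (start, start) marks a sequence that exceeds the budget alone."""
        i = 0
        while i < len(sizes):
            current_sum = 0
            start_index = i
            while (i < len(sizes)
                   and current_sum + sizes[start_index] <= max_value):
                current_sum += sizes[start_index]
                i += 1
            yield (start_index, i)
            if i == start_index:
                i += 1


    vocab_size = 32000  # assumed llama-style vocabulary
    max_input_len = 2 * 1024**3 // (vocab_size * 4)  # fp32 logits, 2 GiB cap
    print(max_input_len)  # 16777 tokens

    sizes = [20000, 9000, 5000, 4000, 500]  # already sorted descending
    for start, end in batch_iterator(sizes, max_input_len):
        if start == end:
            print(f'({start}, {end}): long text, decoded chunk by chunk')
        else:
            padded = (end - start) * sizes[start]
            print(f'({start}, {end}): padded cost {padded}')

Costing each member at `sizes[start]` rather than its own length is deliberate: the bound then matches the padded logits tensor, since every sequence in the interval is at most as long as the first.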
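On interpreting the return value: the refactored `get_ppl` returns the per-sequence mean negative log-likelihood (summed loss divided by target count), not the exponentiated perplexity, so callers who want a true perplexity need one extra step. A hedged usage sketch, assuming the usual `lmdeploy.pipeline` entry point and an illustrative model path:

    import math

    from lmdeploy import pipeline
    from transformers import AutoTokenizer

    model_path = 'internlm/internlm2-chat-7b'  # assumed model
    pipe = pipeline(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path,
                                              trust_remote_code=True)

    texts = ['A quick brown fox.', 'Jumps over the lazy dog.']
    input_ids = [tokenizer.encode(t) for t in texts]
    mean_nll = pipe.get_ppl(input_ids)      # one float per sequence
    ppl = [math.exp(x) for x in mean_nll]   # perplexity = exp(mean NLL)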
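Finally, the reason `_get_long_text_ppl` may simply sum chunk losses and chunk target counts is that unreduced cross-entropy decomposes over positions, so summing per-chunk sums equals the full-sequence sum. A small self-contained check with random logits (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seq_len, vocab = 10, 50
    logits = torch.randn(seq_len, vocab)
    targets = torch.randint(vocab, (seq_len, ))

    full = F.cross_entropy(logits, targets, reduction='sum')
    chunked = sum(
        F.cross_entropy(logits[i:i + 4], targets[i:i + 4], reduction='sum')
        for i in range(0, seq_len, 4))
    assert torch.allclose(full, chunked)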