diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py index 44476a958b..4791d3c724 100644 --- a/lmdeploy/serve/utils.py +++ b/lmdeploy/serve/utils.py @@ -203,18 +203,27 @@ def get_ppl(self, input_ids: Union[List[int], sizes = [len(_) for _ in input_ids] losses = [] target_counts = [] + sorted_index_values = sorted(list(enumerate(sizes)), + key=lambda x: x[1], + reverse=True) + sizes = [value for index, value in sorted_index_values] + indices = [index for index, value in sorted_index_values] + logger.info(f'sorted sizes: {sizes}') + logger.info(f'sorted indices: {indices}') for (start, end) in self._batch_iterator(sizes, max_input_len): + logger.info(f'start: {start}, end: {end}') + _input_ids = [input_ids[indices[i]] for i in range(start, end)] if start == end: loss, target_count = self._get_long_text_ppl( generator=generator, - input_ids=input_ids[start], + input_ids=_input_ids, max_input_len=max_input_len) losses.append(loss) target_counts.append(target_count) else: loss, target_count = self._get_ppl( generator=generator, - input_ids=input_ids[start:end], + input_ids=_input_ids, max_input_len=max_input_len, ) losses.append(loss) @@ -222,17 +231,27 @@ def get_ppl(self, input_ids: Union[List[int], loss = torch.concatenate(losses) target_count = torch.concatenate(target_counts) loss_avg = loss / target_count - loss_avg = loss_avg.numpy() - return loss_avg - - def _batch_iterator(self, sizes, max_length): + loss_avg = loss_avg.numpy().tolist() + result = list(range(len(loss_avg))) + for index, sorted_index in enumerate(indices): + result[sorted_index] = loss_avg[index] + return result + + def _batch_iterator(self, sizes, max_value): + """Return an iterator that calculates intervals (start, end) of a + descend-order list, in which the sum of values in the range is the + maximum number not less than max_value. By "the sum of values", + + here it means $$len(sizes[start:end]) * sizes[start]$$ + """ i = 0 while i < len(sizes): current_sum = 0 start_index = i - while i < len(sizes) and current_sum + sizes[i] <= max_length: - current_sum += sizes[i] + while i < len( + sizes) and current_sum + sizes[start_index] <= max_value: + current_sum += sizes[start_index] i += 1 yield (start_index, i) @@ -242,24 +261,25 @@ def _batch_iterator(self, sizes, max_length): i += 1 def _get_long_text_ppl(self, generator, input_ids, max_input_len): - assert all(isinstance(_, int) for _ in input_ids) - assert len(input_ids) > max_input_len + assert isinstance(input_ids, List) and len(input_ids) == 1 + seq_len = len(input_ids[0]) + assert seq_len > max_input_len + logger.info(f'get long text ppl: seq_len {seq_len}') - seq_len = len(input_ids) losses = [] target_counts = [] - for i in range(0, len(input_ids), max_input_len): - token_ids = input_ids[i:i + max_input_len] - step = i + for i in range(0, seq_len, max_input_len): + token_ids = input_ids[:, i:i + max_input_len] + step = [i] # shift token_ids by 1 to the left - target_ids = input_ids[i + 1:i + 1 + max_input_len] + target_ids = input_ids[:, i + 1:i + 1 + max_input_len] loss, target_count = self._get_ppl( generator=generator, - input_ids=[token_ids], + input_ids=token_ids, max_input_len=max_input_len, - target_ids=[target_ids], - steps=[step], + target_ids=target_ids, + steps=step, sequence_start=(i == 0), sequence_end=(i + max_input_len >= seq_len)) losses.append(loss) @@ -278,11 +298,16 @@ def _get_ppl(self, sequence_end: bool = True): assert isinstance(input_ids, List) assert all(isinstance(_, List) for _ in input_ids) - assert sum(len(_) for _ in input_ids) <= max_input_len if target_ids: assert all(isinstance(_, List) for _ in target_ids) - logger.info(f'get_ppl batch_size {len(input_ids)}') + lens = [len(_) for _ in input_ids] + total_len = sum(lens) + assert sum(lens) <= max_input_len + + logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' + f'total_len: {total_len}') + torch.cuda.empty_cache() logits = generator.decode(input_ids=input_ids, steps=steps, sequence_start=sequence_start,