Skip to content

Commit

Permalink
apply torch.cuda.empty_cache()
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Sep 26, 2024
1 parent 4066eb5 commit c760794
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions lmdeploy/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,36 +203,55 @@ def get_ppl(self, input_ids: Union[List[int],
sizes = [len(_) for _ in input_ids]
losses = []
target_counts = []
sorted_index_values = sorted(list(enumerate(sizes)),
key=lambda x: x[1],
reverse=True)
sizes = [value for index, value in sorted_index_values]
indices = [index for index, value in sorted_index_values]
logger.info(f'sorted sizes: {sizes}')
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
if start == end:
loss, target_count = self._get_long_text_ppl(
generator=generator,
input_ids=input_ids[start],
input_ids=_input_ids,
max_input_len=max_input_len)
losses.append(loss)
target_counts.append(target_count)
else:
loss, target_count = self._get_ppl(
generator=generator,
input_ids=input_ids[start:end],
input_ids=_input_ids,
max_input_len=max_input_len,
)
losses.append(loss)
target_counts.append(target_count)
loss = torch.concatenate(losses)
target_count = torch.concatenate(target_counts)
loss_avg = loss / target_count
loss_avg = loss_avg.numpy()
return loss_avg

def _batch_iterator(self, sizes, max_length):
loss_avg = loss_avg.numpy().tolist()
result = list(range(len(loss_avg)))
for index, sorted_index in enumerate(indices):
result[sorted_index] = loss_avg[index]
return result

def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
descend-order list, in which the sum of values in the range is the
maximum number not less than max_value. By "the sum of values",
here it means $$len(sizes[start:end]) * sizes[start]$$
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i

while i < len(sizes) and current_sum + sizes[i] <= max_length:
current_sum += sizes[i]
while i < len(
sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1

yield (start_index, i)
Expand All @@ -242,24 +261,25 @@ def _batch_iterator(self, sizes, max_length):
i += 1

def _get_long_text_ppl(self, generator, input_ids, max_input_len):
assert all(isinstance(_, int) for _ in input_ids)
assert len(input_ids) > max_input_len
assert isinstance(input_ids, List) and len(input_ids) == 1
seq_len = len(input_ids[0])
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')

seq_len = len(input_ids)
losses = []
target_counts = []
for i in range(0, len(input_ids), max_input_len):
token_ids = input_ids[i:i + max_input_len]
step = i
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[:, i:i + max_input_len]
step = [i]
# shift token_ids by 1 to the left
target_ids = input_ids[i + 1:i + 1 + max_input_len]
target_ids = input_ids[:, i + 1:i + 1 + max_input_len]

loss, target_count = self._get_ppl(
generator=generator,
input_ids=[token_ids],
input_ids=token_ids,
max_input_len=max_input_len,
target_ids=[target_ids],
steps=[step],
target_ids=target_ids,
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
losses.append(loss)
Expand All @@ -278,11 +298,16 @@ def _get_ppl(self,
sequence_end: bool = True):
assert isinstance(input_ids, List)
assert all(isinstance(_, List) for _ in input_ids)
assert sum(len(_) for _ in input_ids) <= max_input_len
if target_ids:
assert all(isinstance(_, List) for _ in target_ids)

logger.info(f'get_ppl batch_size {len(input_ids)}')
lens = [len(_) for _ in input_ids]
total_len = sum(lens)
assert sum(lens) <= max_input_len

logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
torch.cuda.empty_cache()
logits = generator.decode(input_ids=input_ids,
steps=steps,
sequence_start=sequence_start,
Expand Down

0 comments on commit c760794

Please sign in to comment.