Commit

split batch dim
lvhan028 committed Sep 25, 2024
1 parent c64c00f commit 4066eb5
Showing 1 changed file with 116 additions and 57 deletions.
lmdeploy/serve/utils.py
@@ -178,9 +178,8 @@ def _split_embeddings(input_ids, niter, iter_len, embeddings,
         logits = torch.cat(logits, dim=1)
         return logits

-    def get_ppl(
-            self, input_ids: Union[List[int],
-                                   List[List[int]]]) -> Union[float, List[float]]:
+    def get_ppl(self, input_ids: Union[List[int],
+                                       List[List[int]]]) -> List[float]:
         """Get perplexity scores given a list of input tokens that have to be
         of the same length.
@@ -197,68 +196,128 @@ def get_ppl(

         generator = self.engine.create_instance()

-        bs = len(input_ids)
-        max_seq_len = max([len(_) for _ in input_ids])
-
         # TODO: a better way to determine `max_input_len`, at most allocate
         # 2G mem for logits with shape [bs, max_input_len, vocab_size]
         vocab_size = self.hf_tm_cfg.vocab_size
-        max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
+        max_input_len = 2 * 1024**3 // (vocab_size * 4)
+        sizes = [len(_) for _ in input_ids]
+        losses = []
+        target_counts = []
+        for (start, end) in self._batch_iterator(sizes, max_input_len):
+            if start == end:
+                loss, target_count = self._get_long_text_ppl(
+                    generator=generator,
+                    input_ids=input_ids[start],
+                    max_input_len=max_input_len)
+                losses.append(loss)
+                target_counts.append(target_count)
+            else:
+                loss, target_count = self._get_ppl(
+                    generator=generator,
+                    input_ids=input_ids[start:end],
+                    max_input_len=max_input_len,
+                )
+                losses.append(loss)
+                target_counts.append(target_count)
+        loss = torch.concatenate(losses)
+        target_count = torch.concatenate(target_counts)
+        loss_avg = loss / target_count
+        loss_avg = loss_avg.numpy()
+        return loss_avg
+
+    def _batch_iterator(self, sizes, max_length):
+        i = 0
+        while i < len(sizes):
+            current_sum = 0
+            start_index = i
+
+            while i < len(sizes) and current_sum + sizes[i] <= max_length:
+                current_sum += sizes[i]
+                i += 1
+
+            yield (start_index, i)
+            if i > start_index:
+                continue
+            else:
+                i += 1
+
+    def _get_long_text_ppl(self, generator, input_ids, max_input_len):
+        assert all(isinstance(_, int) for _ in input_ids)
+        assert len(input_ids) > max_input_len
+
+        seq_len = len(input_ids)
         losses = []
         target_counts = []
-        for i in range(0, max_seq_len, max_input_len):
-            token_ids = [
-                input_id[i:i + max_input_len] for input_id in input_ids
-            ]
-            steps = [i] * bs
-            logits = generator.decode(
-                input_ids=token_ids,
-                steps=steps,
+        for i in range(0, len(input_ids), max_input_len):
+            token_ids = input_ids[i:i + max_input_len]
+            step = i
+            # shift token_ids by 1 to the left
+            target_ids = input_ids[i + 1:i + 1 + max_input_len]
+
+            loss, target_count = self._get_ppl(
+                generator=generator,
+                input_ids=[token_ids],
+                max_input_len=max_input_len,
+                target_ids=[target_ids],
+                steps=[step],
                 sequence_start=(i == 0),
-                sequence_end=(i + max_input_len >= max_seq_len))
-            bsz, seq_len, vocab_size = logits.shape
-            logits = logits.float()
-            padding_token_id = -100
-            # meaning logits[..., :, :] corresponds to labels
-            # token_ids[1:] + predict_token_id, which is
-            # input_ids[:, i+max_input_len:i+max_input_len+1]
-            target_ids = [
-                input_id[i + 1:i + 1 + max_input_len] for input_id in input_ids
-            ]
+                sequence_end=(i + max_input_len >= seq_len))
+            losses.append(loss)
+            target_counts.append(target_count)
+        loss_sum = torch.concatenate(losses).sum().unsqueeze(0)
+        target_count = torch.concatenate(target_counts).sum().unsqueeze(0)
+        return loss_sum, target_count
+
+    def _get_ppl(self,
+                 generator,
+                 input_ids,
+                 max_input_len,
+                 target_ids=None,
+                 steps=None,
+                 sequence_start: bool = True,
+                 sequence_end: bool = True):
+        assert isinstance(input_ids, List)
+        assert all(isinstance(_, List) for _ in input_ids)
+        assert sum(len(_) for _ in input_ids) <= max_input_len
+        if target_ids:
+            assert all(isinstance(_, List) for _ in target_ids)
+
+        logger.info(f'get_ppl batch_size {len(input_ids)}')
+        logits = generator.decode(input_ids=input_ids,
+                                  steps=steps,
+                                  sequence_start=sequence_start,
+                                  sequence_end=sequence_end)
+        bsz, seq_len, vocab_size = logits.shape
+        logits = logits.float()
+        padding_token_id = -100
+        if target_ids is None:
+            # shift token_ids by 1 to the left
+            target_ids = [x[1:] + [padding_token_id] for x in input_ids]
+        else:
             target_ids = [
                 target_ids[i] + [padding_token_id]
-                if len(target_ids[i]) < len(token_ids[i]) else target_ids[i]
+                if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
                 for i in range(bsz)
             ]
-            target_ids = [
-                torch.Tensor(torch.LongTensor(_target_ids))
-                for _target_ids in target_ids
-            ]
-            target_ids = pad_sequence(target_ids,
-                                      batch_first=True,
-                                      padding_value=padding_token_id)
-            target_ids = target_ids.to(logits.device)
-            target_mask = target_ids != padding_token_id
-
-            # compute cross entropy loss
-            flat_logits = logits.contiguous().view(-1, vocab_size)
-            flat_target_ids = target_ids.contiguous().view(-1)
-            flat_loss_matrix = torch.nn.functional.cross_entropy(
-                flat_logits,
-                flat_target_ids,
-                reduction='none',
-                ignore_index=padding_token_id)
-            flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
-            loss = flat_loss_matrix.sum(dim=-1).cpu().view(bsz, -1)
-            target_count = target_mask.sum(dim=-1).cpu().view(bsz, -1)
-            losses.append(loss)
-            target_counts.append(target_count)
-
-        target_count = torch.concatenate(target_counts, dim=-1).sum(dim=-1)
-        loss_sum = torch.concatenate(losses, dim=-1).sum(dim=-1)
-
-        loss_avg = loss_sum / target_count
-        loss_avg = loss_avg.numpy()
-
-        return loss_avg
+        target_ids = [
+            torch.Tensor(torch.LongTensor(_target_ids))
+            for _target_ids in target_ids
+        ]
+        target_ids = pad_sequence(target_ids,
+                                  batch_first=True,
+                                  padding_value=padding_token_id)
+        target_ids = target_ids.to(logits.device)
+        target_mask = target_ids != padding_token_id
+
+        # compute cross entropy loss
+        flat_logits = logits.contiguous().view(-1, vocab_size)
+        flat_target_ids = target_ids.contiguous().view(-1)
+        flat_loss_matrix = torch.nn.functional.cross_entropy(
+            flat_logits,
+            flat_target_ids,
+            reduction='none',
+            ignore_index=padding_token_id)
+        flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
+        loss = flat_loss_matrix.sum(dim=-1).cpu()
+        target_count = target_mask.sum(dim=-1).cpu()
+        return loss, target_count

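The loss computation in _get_ppl is the standard next-token setup: targets are the inputs shifted left by one, padded with -100 so ignore_index masks positions that have no label. A toy reproduction of that step, assuming random logits and an invented vocabulary size:

# Toy reproduction of the shifted-target cross entropy used in _get_ppl.
# The vocabulary size and token ids are made up for illustration.
import torch
import torch.nn.functional as F

vocab_size = 16
input_ids = [3, 7, 2, 9]                      # one sequence
logits = torch.randn(1, len(input_ids), vocab_size)

padding_token_id = -100
# shift token ids by one to the left; the last position has no label
target_ids = torch.tensor([input_ids[1:] + [padding_token_id]])

flat_loss = F.cross_entropy(logits.view(-1, vocab_size),
                            target_ids.view(-1),
                            reduction='none',
                            ignore_index=padding_token_id)
target_count = (target_ids != padding_token_id).sum()
loss_avg = flat_loss.sum() / target_count     # mean NLL per predicted token
print(loss_avg.item(), torch.exp(loss_avg).item())  # NLL and perplexity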
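After this commit, get_ppl returns one mean negative log-likelihood per input sequence, which converts to perplexity with an exponential. A hedged usage sketch, assuming the usual lmdeploy pipeline entry point with a tokenizer attribute; the model path and prompts are placeholders:

# Hypothetical usage sketch; model path and prompts are placeholders,
# and the tokenizer attribute on the pipeline is assumed.
import numpy as np
from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')
token_ids = [pipe.tokenizer.encode(p)
             for p in ['Hello world.', 'A longer test sentence.']]
loss_avg = pipe.get_ppl(token_ids)   # one mean NLL per input sequence
ppl = np.exp(loss_avg)               # perplexity per sequence
print(ppl)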