Compute cross entropy loss given a list of input tokens #830

Merged · 5 commits · Dec 27, 2023
101 changes: 98 additions & 3 deletions lmdeploy/turbomind/turbomind.py
@@ -10,7 +10,7 @@
from contextlib import contextmanager
from queue import Queue
from threading import Thread
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Union

import numpy as np
import torch
@@ -797,17 +797,27 @@ def stream_infer(self,
        if stream_output and not stop:
            self.model_insts[0].unregister_callback()

-    def decode(self, input_ids):
+    def decode(self,
+               input_ids,
+               steps: List[int] = None,
+               sequence_start: bool = True,
+               sequence_end: bool = True):
"""Perform context decode on input tokens.

Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
"""

if len(input_ids) == 0:
input_ids = [[]]
if isinstance(input_ids[0], int):
input_ids = [input_ids]
if steps is None:
steps = [0] * len(input_ids)
assert isinstance(steps, List) and len(steps) == len(input_ids)

# append an extra token since input_len-1 tokens will be
# decoded by context decoder
@@ -828,11 +838,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
        input_ids = pad_sequence(input_ids,
                                 batch_first=True,
                                 padding_value=self.eos_id)
        steps = torch.IntTensor([step for step in steps])

        inputs = dict(input_ids=input_ids,
                      input_lengths=input_lengths,
                      request_output_len=_broadcast_np(0, dtype=np.uint32),
-                     is_return_logits=_broadcast_np(1, np.uint32))
+                     is_return_logits=_broadcast_np(1, np.uint32),
+                     START=_broadcast_np((1 if sequence_start else 0),
+                                         np.int32),
+                     END=_broadcast_np((1 if sequence_end else 0), np.int32),
+                     step=steps)

        tm_inputs = _np_dict_to_tm_dict(inputs)

@@ -845,3 +860,83 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
        logits = outputs['logits']

        return logits[:, :-1, :]

    def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
        """Get perplexity scores given a list of input tokens.

        Args:
            input_ids (Union[List[int], List[List[int]]]): the batch of input token ids
        """  # noqa 501

        if len(input_ids) == 0:
            input_ids = [[]]
        if isinstance(input_ids[0], int):
            input_ids = [input_ids]

        max_input_len = 16 * 1024
        # max_input_len = 16
        n_max_iter = np.ceil(
            max([len(input_id)
                 for input_id in input_ids]) / max_input_len).astype(int)

        device = 'cpu' if n_max_iter > 0 else 'cuda'

        index_range_starts = []
        index_range_ends = []
        for input_id in input_ids:
            index_range_start = np.array(
                [i * max_input_len for i in range(n_max_iter)])
            index_range_end = index_range_start + max_input_len
            index_range_start[index_range_start >= len(input_id)] = len(
                input_id)
            index_range_end[index_range_end >= len(input_id)] = len(input_id)
            index_range_starts.append(index_range_start)
            index_range_ends.append(index_range_end)

        logits = []
        for i in range(n_max_iter):
            steps = [start[i] for start in index_range_starts]
            _input_ids = [
                input_id[start[i]:end[i]] for input_id, start, end in zip(
                    input_ids, index_range_starts, index_range_ends)
            ]
            _logits = self.decode(_input_ids,
                                  steps,
                                  sequence_start=(i == 0),
                                  sequence_end=(i == n_max_iter - 1))
            _logits = _logits.to(device=device)
            logits.append(_logits)

        # concat logits. Shape is [bsz, seq_len, vocab_size]
        logits = torch.cat(logits, dim=1)

        # get target ids
        padding_token_id = -100
        target_ids = [(_input_ids + [padding_token_id])[1:]
                      for _input_ids in input_ids]
        target_ids = [
            torch.Tensor(torch.LongTensor(_target_ids))
            for _target_ids in target_ids
        ]
        target_ids = pad_sequence(target_ids,
                                  batch_first=True,
                                  padding_value=padding_token_id)
        target_ids = target_ids.to(logits.device)
        target_mask = target_ids != padding_token_id
        target_count = torch.sum(target_mask, dim=-1)

        # compute cross entropy loss
        bsz, seq_len, vocab_size = logits.shape
        flat_logits = logits.contiguous().view(-1, vocab_size)
        flat_target_ids = target_ids.contiguous().view(-1)
        flat_loss_matrix = torch.nn.functional.cross_entropy(
            flat_logits,
            flat_target_ids,
            reduction='none',
            ignore_index=padding_token_id)

        loss_matrix = flat_loss_matrix.view(bsz, seq_len)
        loss_sum = torch.sum(loss_matrix * target_mask, dim=1)
        loss_avg = loss_sum / target_count
        loss_avg = loss_avg.cpu().numpy()
        return loss_avg
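
For reference, below is a minimal standalone sketch (not taken from this PR; the helper name nll_from_logits and the toy vocabulary size are illustrative assumptions) of the loss computation get_ppl performs once the logits are available: targets are the input ids shifted left by one, padded positions are filled with -100 and masked out, and the cross entropy is averaged over each sequence's real tokens. Since get_ppl returns this averaged loss, exponentiating it yields the perplexity.

# Illustrative helper, not part of lmdeploy or this PR.
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def nll_from_logits(logits: torch.Tensor, input_ids: list) -> torch.Tensor:
    """logits: [bsz, seq_len, vocab_size]; input_ids: list of token-id lists."""
    pad_id = -100
    # next-token targets: drop the first token, append one pad at the end
    targets = [torch.LongTensor((ids + [pad_id])[1:]) for ids in input_ids]
    targets = pad_sequence(targets, batch_first=True, padding_value=pad_id)
    mask = targets != pad_id
    bsz, seq_len, vocab_size = logits.shape
    loss = F.cross_entropy(logits.reshape(-1, vocab_size),
                           targets.reshape(-1),
                           reduction='none',
                           ignore_index=pad_id).view(bsz, seq_len)
    # average negative log-likelihood over the unpadded tokens of each sequence
    return (loss * mask).sum(dim=1) / mask.sum(dim=1)


# toy usage with random logits over a vocabulary of 32 tokens
ids = [[5, 7, 2, 9], [3, 1, 4]]
nll = nll_from_logits(torch.randn(2, 4, 32), ids)
ppl = torch.exp(nll)  # perplexity per sequence
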
28 changes: 22 additions & 6 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -1065,17 +1065,22 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
}

template<typename T>
-void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
-                                        const std::vector<int>& indices,
-                                        const std::vector<int>& lengths)
+void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
+                                        const std::vector<int>& indices,
+                                        const std::vector<int>& lengths,
+                                        const std::vector<const Sequence*>& sequences)
{
    std::vector<float*> output_logits;
    int num_token = 0;
    {
        bool is_return_logits = false;
        for (int k = 0; k < indices.size(); ++k) {
            auto& request = state_->requests[indices[k]];
-           output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
+           auto logits = request->outputs[rank_].getPtr<float>("logits", nullptr);
+           if (logits && sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) {
+               logits = nullptr;
+           }
+           output_logits.push_back(logits);
            num_token += lengths[k];
            if (output_logits.back()) {
                is_return_logits = true;
@@ -1105,7 +1110,18 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_

    for (int k = 0; k < indices.size(); ++k) {
        if (output_logits[k]) {
-           Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
+           auto src_ptr = logits;
+           auto dst_ptr = output_logits[k];
+           int num_new_token = 0;
+           if (sequences[k]->cache_len < sequences[k]->tokens.size()) {
+               num_new_token = sequences[k]->cache_len + lengths[k] - sequences[k]->tokens.size();
+               src_ptr += (lengths[k] - num_new_token) * model_->vocab_size_;
+           }
+           else {
+               num_new_token = lengths[k];
+               dst_ptr += (sequences[k]->cache_len - sequences[k]->tokens.size()) * model_->vocab_size_;
+           }
+           Copy(src_ptr, model_->vocab_size_ * num_new_token, dst_ptr);
        }
        logits += model_->vocab_size_padded_ * lengths[k];
    }
@@ -1562,7 +1578,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)

        if (iter == 0) {
            // compute logits of inputs if requested
-           OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
+           OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences);
        }
    }

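To make the new copy logic in OutputContextLogits easier to follow, here is a small sketch (illustrative only; the function and argument names are assumptions, not code from this PR) of the same offset arithmetic: when a chunk overlaps tokens the sequence already holds (cache_len < tokens.size()), the overlapping rows are skipped on the source side; otherwise the whole chunk is new and the destination pointer advances past the rows written for earlier chunks. All offsets are in rows of vocab_size_.

# Illustrative helper mirroring the branch added above; not part of lmdeploy.
def logit_copy_offsets(cache_len: int, tokens_size: int, length: int):
    """Return (src_offset, dst_offset, num_new) in rows of vocab_size."""
    if cache_len < tokens_size:
        # chunk overlaps tokens already held: skip their logit rows in the source
        num_new = cache_len + length - tokens_size
        src_offset, dst_offset = length - num_new, 0
    else:
        # whole chunk is new: append after rows produced by earlier chunks
        num_new = length
        src_offset, dst_offset = 0, cache_len - tokens_size
    return src_offset, dst_offset, num_new


# e.g. a 16-token chunk starting at cache_len=4 while the sequence already holds
# 8 tokens: the first 4 rows of the chunk's logits repeat held tokens and are
# skipped; 12 new rows are written at the start of the output buffer.
assert logit_copy_offsets(cache_len=4, tokens_size=8, length=16) == (4, 0, 12)
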
6 changes: 4 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.h
@@ -92,8 +92,10 @@ class LlamaBatch {

    [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);

-    void
-    OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
+    void OutputContextLogits(T* context_decoder_output,
+                             const std::vector<int>& indices,
+                             const std::vector<int>& lengths,
+                             const std::vector<const Sequence*>& sequences);

    explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model);
