Skip to content

Commit

Permalink
fix logits vocab (#1786)
Browse files Browse the repository at this point in the history
* fix logits vocab

* sort hypotheses
Loading branch information
minhthuc2502 authored Sep 25, 2024
1 parent 5ad036d commit 6ebddf3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/cpp/generation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace ctranslate2 {
.def_readonly("scores", &GenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &GenerationResult::logits,
"Score of each sequence (empty if :obj:`return_logits_vocab` was disabled).")
"Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const GenerationResult& result) {
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
Expand Down
2 changes: 1 addition & 1 deletion python/cpp/translation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace ctranslate2 {
.def_readonly("attention", &TranslationResult::attention,
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
.def_readonly("logits", &TranslationResult::logits,
"Score of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).")
"Logits of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const TranslationResult& result) {
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
Expand Down
2 changes: 2 additions & 0 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ namespace ctranslate2 {
"Generated sequences of token IDs.")
.def_readonly("scores", &models::WhisperGenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &models::WhisperGenerationResult::logits,
"logits in each sequence (empty if :obj:`return_logits_vocab` was disabled).")
.def_readonly("no_speech_prob", &models::WhisperGenerationResult::no_speech_prob,
"Probability of the no speech token (0 if :obj:`return_no_speech_prob` was disabled).")

Expand Down
31 changes: 19 additions & 12 deletions src/decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ namespace ctranslate2 {
static inline void sort_hypotheses(DecodingResult& result,
size_t max_hypotheses,
bool keep_scores,
bool keep_attention) {
bool keep_attention,
bool keep_logits_vocab) {
std::vector<size_t> idx(result.hypotheses.size());
std::iota(idx.begin(), idx.end(), 0);
std::sort(idx.begin(), idx.end(),
Expand All @@ -226,14 +227,20 @@ namespace ctranslate2 {
result.attention = index_vector(result.attention, idx);
else
result.attention.clear();

if (keep_logits_vocab)
result.logits_vocab = index_vector(result.logits_vocab, idx);
else
result.logits_vocab.clear();
}

static inline void finalize_result(DecodingResult& result,
const size_t max_hypotheses,
const float length_penalty,
const float coverage_penalty,
const bool keep_scores,
const bool keep_attention) {
const bool keep_attention,
const bool keep_logits_vocab) {
for (size_t i = 0; i < result.scores.size(); ++i) {
const auto* attention = result.attention.empty() ? nullptr : &result.attention[i];
result.scores[i] = finalize_hypothesis_score(result.scores[i],
Expand All @@ -243,7 +250,7 @@ namespace ctranslate2 {
attention);
}

sort_hypotheses(result, max_hypotheses, keep_scores, keep_attention);
sort_hypotheses(result, max_hypotheses, keep_scores, keep_attention, keep_logits_vocab);
}

BiasedDecoder::BiasedDecoder(const float prefix_bias_beta,
Expand Down Expand Up @@ -520,7 +527,7 @@ namespace ctranslate2 {
disable_tokens.apply();
std::vector<StorageView> logits_vec;
if (return_logits_vocab)
logits_vec = build_logits(logits, cur_batch_size);
logits_vec = build_logits(logits, cur_batch_size * _beam_size);

StorageView log_probs(dtype, device);
if (bias_towards_prefix) {
Expand Down Expand Up @@ -602,11 +609,6 @@ namespace ctranslate2 {
auto& result = results[batch_id];
dim_t secondary_candidates_offset = _beam_size;

if (return_logits_vocab) {
results[batch_id].logits_vocab.resize(1);
results[batch_id].logits_vocab[0].emplace_back(std::move(logits_vec[i]));
}

for (dim_t k = 0; k < _beam_size; ++k) {
const size_t last_id = topk_ids.at<int32_t>({i, k});
dim_t next_beam_id = k;
Expand All @@ -624,6 +626,9 @@ namespace ctranslate2 {
result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end));
if (alive_attention)
result.attention.emplace_back(build_attention(alive_attention, i, k, start, end));
if (return_logits_vocab) {
result.logits_vocab.emplace_back(std::move(logits_vec[i * k]));
}

// Move another active beam to this position.
for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) {
Expand Down Expand Up @@ -653,7 +658,8 @@ namespace ctranslate2 {
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
return_attention,
return_logits_vocab);
} else {
non_finished_index.emplace_back(i);
}
Expand Down Expand Up @@ -798,7 +804,7 @@ namespace ctranslate2 {
}

for (auto& result : final_results)
sort_hypotheses(result, num_hypotheses, return_scores, return_attention);
sort_hypotheses(result, num_hypotheses, return_scores, return_attention, return_logits_vocab);

return final_results;
}
Expand Down Expand Up @@ -934,7 +940,8 @@ namespace ctranslate2 {
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
return_attention,
return_logits_vocab);
} else {
non_finished_index.emplace_back(i);
sample_from.at<int32_t>(i) = word_id;
Expand Down

0 comments on commit 6ebddf3

Please sign in to comment.