Skip to content

Commit

Permalink
Add support for encoder-only T5 models (#8900)
Browse files Browse the repository at this point in the history
* gguf-py : add T5ENCODER model architecture

* common : call llama_decode() during warmup only if the model has decoder

* convert-hf : add T5EncoderModel

* llama : add llama_model_has_decoder() API function

* llama : split build_t5() into build_t5_encoder() and build_t5_decoder()

* llama : add support for LLM_ARCH_T5ENCODER

* llama-embedding : add support for LLAMA_POOLING_TYPE_NONE

* llama-embedding : add support for encoder-only models

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
  • Loading branch information
fairydreaming and sszymczy authored Aug 10, 2024
1 parent 911b437 commit 7c3f55c
Show file tree
Hide file tree
Showing 6 changed files with 649 additions and 282 deletions.
4 changes: 3 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
tmp.clear();
tmp.push_back(decoder_start_token_id);
}
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
}
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
llama_reset_timings(lctx);
Expand Down
139 changes: 139 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,145 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("T5EncoderModel")
class T5EncoderModel(Model):
model_arch = gguf.MODEL_ARCH.T5ENCODER

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.shared_token_embeddings_found = False

def set_vocab(self):
# to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model

tokenizer_path = self.dir_model / 'tokenizer.model'

# many older models use spiece.model tokenizer model filename
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'spiece.model'

if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")

sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

# some models like Pile-T5 family use BPE tokenizer instead of Unigram
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# assure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM

add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap

tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))

vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.IsUnknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.IsControl(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.IsUnused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype

added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)
for key in added_tokens_json:
token_id = added_tokens_json[key]
if token_id >= vocab_size:
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue

tokens[token_id] = key.encode("utf-8")
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)

self.gguf_writer.add_tokenizer_model("t5")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
if precompiled_charsmap:
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)

def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
# and decoder and ignore the remaining ones.
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
if not self.shared_token_embeddings_found:
name = "shared.weight"
self.shared_token_embeddings_found = True
else:
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
return []

return [(self.map_tensor_name(name), data_torch)]


@Model.register("JAISLMHeadModel")
class JaisModel(Model):
model_arch = gguf.MODEL_ARCH.JAIS
Expand Down
140 changes: 98 additions & 42 deletions examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,47 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}

static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);

// run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to decode\n", __func__);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to decode\n", __func__);
}
}

for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}

// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
const float * embd = nullptr;
int embd_pos = 0;

if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
embd_pos = i;
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}

float * out = output + batch.seq_id[i][0] * n_embd;
float * out = output + embd_pos * n_embd;
llama_embd_normalize(embd, out, n_embd, embd_norm);
}
}
Expand Down Expand Up @@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
const int n_ctx = llama_n_ctx(ctx);

const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);

if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
return 1;
}

Expand Down Expand Up @@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

// count number of embeddings
int n_embd_count = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int k = 0; k < n_prompts; k++) {
n_embd_count += inputs[k].size();
}
} else {
n_embd_count = n_prompts;
}

// allocate output
const int n_embd = llama_n_embd(model);
std::vector<float> embeddings(n_prompts * n_embd, 0);
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data();

// break into batches
int p = 0; // number of prompts processed already
int e = 0; // number of embeddings already stored
int s = 0; // number of prompts in current batch
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
Expand All @@ -169,11 +202,11 @@ int main(int argc, char ** argv) {

// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
llama_batch_clear(batch);
p += s;
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
s = 0;
llama_batch_clear(batch);
}

// add to batch
Expand All @@ -182,39 +215,62 @@ int main(int argc, char ** argv) {
}

// final batch
float * out = emb + p * n_embd;
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);

if (params.embd_out.empty()) {
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);

if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int j = 0; j < n_embd_count; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd); i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
}
fprintf(stdout, " ... ");
for (int i = n_embd - 3; i < n_embd; i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
}
fprintf(stdout, "\n");
}
fprintf(stdout, "\n");
}

// print cosine similarity matrix
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
fprintf(stdout, "%6.6s ", prompts[i].c_str());
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
}
fprintf(stdout, "\n");
}
fprintf(stdout, "\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);

// print cosine similarity matrix
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
fprintf(stdout, "%6.6s ", prompts[i].c_str());
}
fprintf(stdout, "%1.10s", prompts[i].c_str());
fprintf(stdout, "\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "%1.10s", prompts[i].c_str());
fprintf(stdout, "\n");
}
}
}
}
Expand All @@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
}
fprintf(stdout, notArray ? "]\n }" : "]");
j++;
if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
}
fprintf(stdout, notArray ? "\n ]" : "]\n");

if (params.embd_out == "json+" && n_prompts > 1) {
fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
for (int i = 0;;) { // at least two iteration (n_prompts > 1)
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
fprintf(stdout, " [");
for (int j = 0;;) { // at least two iteration (n_prompts > 1)
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f", sim);
j++;
if (j < n_prompts) fprintf(stdout, ", "); else break;
if (j < n_embd_count) fprintf(stdout, ", "); else break;
}
fprintf(stdout, " ]");
i++;
if (i < n_prompts) fprintf(stdout, ",\n"); else break;
if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
}
fprintf(stdout, "\n ]");
}
Expand Down
Loading

0 comments on commit 7c3f55c

Please sign in to comment.