From 3fba2b5fc44f5c4b1963b0088018a25dd74ab2e9 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 3 Nov 2024 17:11:12 +0100
Subject: [PATCH] Add the SmolLM2 models. (#2595)

* Add the SmolLM2 models.

* More SmolLM2 support.
---
 candle-examples/examples/llama/main.rs     | 57 ++++++++++++++-----
 candle-examples/examples/quantized/main.rs | 25 +++++++-
 .../src/models/quantized_llama.rs          |  9 ++-
 3 files changed, 73 insertions(+), 18 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index cc99b6c191..99077b35e9 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -43,6 +43,18 @@ enum Which {
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
     TinyLlama1_1BChat,
+    #[value(name = "SmolLM2-1.7B")]
+    SmolLM2_1B,
+    #[value(name = "SmolLM2-1.7B-Instruct")]
+    SmolLM2_1BInstruct,
+    #[value(name = "SmolLM2-360M")]
+    SmolLM2_360M,
+    #[value(name = "SmolLM2-360M-Instruct")]
+    SmolLM2_360MInstruct,
+    #[value(name = "SmolLM2-135M")]
+    SmolLM2_135M,
+    #[value(name = "SmolLM2-135M-Instruct")]
+    SmolLM2_135MInstruct,
 }
 
 #[derive(Parser, Debug)]
@@ -134,19 +146,28 @@ fn main() -> Result<()> {
     };
     let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
-        let model_id = args.model_id.unwrap_or_else(|| match args.which {
-            Which::V1 => "Narsil/amall-7b".to_string(),
-            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
-            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
-            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
-            Which::V31 => "meta-llama/Llama-3.1-8B".to_string(),
-            Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct".to_string(),
-            Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
-            Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
-            Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
-            Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
-            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
-            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
+        let model_id = args.model_id.unwrap_or_else(|| {
+            let str = match args.which {
+                Which::V1 => "Narsil/amall-7b",
+                Which::V2 => "meta-llama/Llama-2-7b-hf",
+                Which::V3 => "meta-llama/Meta-Llama-3-8B",
+                Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
+                Which::V31 => "meta-llama/Llama-3.1-8B",
+                Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
+                Which::V32_1b => "meta-llama/Llama-3.2-1B",
+                Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
+                Which::V32_3b => "meta-llama/Llama-3.2-3B",
+                Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
+                Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
+                Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
+                Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
+                Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
+                Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
+                Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
+                Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            };
+            str.to_string()
         });
         println!("loading the model weights from {model_id}");
         let revision = args.revision.unwrap_or("main".to_string());
@@ -169,7 +190,15 @@ fn main() -> Result<()> {
             | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
-            Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
+            Which::SmolLM2_360M
+            | Which::SmolLM2_360MInstruct
+            | Which::SmolLM2_135M
+            | Which::SmolLM2_135MInstruct
+            | Which::SmolLM2_1B
+            | Which::SmolLM2_1BInstruct
+            | Which::V32_1b
+            | Which::V32_1bInstruct
+            | Which::TinyLlama1_1BChat => {
                 vec![api.get("model.safetensors")?]
            }
        };
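The SmolLM2 checkpoints ship as a single model.safetensors file rather than a
sharded set with an index json, which is why the new variants join the
single-file arm above. A minimal sketch of that load path, assuming the hf-hub
and candle/candle-nn APIs this example already uses (the helper name is
illustrative, not part of the patch):

    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use hf_hub::api::sync::Api;

    // Fetch the single-shard SmolLM2 weights and mmap them into a VarBuilder,
    // mirroring the `vec![api.get("model.safetensors")?]` arm above.
    fn load_smollm2_vb(device: &Device) -> anyhow::Result<VarBuilder<'static>> {
        let repo = Api::new()?.model("HuggingFaceTB/SmolLM2-135M".to_string());
        let weights = repo.get("model.safetensors")?; // no index json needed
        // Safety: mmap of the file we just downloaded, as in the example itself.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, device)? };
        Ok(vb)
    }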
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index d91701ff8b..2b537aac9e 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -71,6 +71,10 @@ enum Which {
     L8b,
     #[value(name = "phi3")]
     Phi3,
+    #[value(name = "SmolLM2-360M-Instruct")]
+    SmolLM2_360MInstruct,
+    #[value(name = "SmolLM2-1.7B-Instruct")]
+    SmolLM2_1BInstruct,
 }
 
 impl Which {
@@ -88,7 +92,9 @@ impl Which {
             | Self::Leo7b
             | Self::Leo13b
             | Self::L8b
-            | Self::Phi3 => false,
+            | Self::Phi3
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -124,6 +130,8 @@ impl Which {
             | Self::OpenChat35
             | Self::Starling7bAlpha
             | Self::L8b
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct
             | Self::Phi3 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
@@ -150,6 +158,8 @@ impl Which {
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
             | Self::L8b
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct
             | Self::Phi3 => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
@@ -179,6 +189,8 @@ impl Which {
             Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
             Self::L8b => "meta-llama/Meta-Llama-3-8B",
             Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+            Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
+            Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
         }
     }
 }
@@ -343,6 +355,14 @@ impl Args {
                 "microsoft/Phi-3-mini-4k-instruct-gguf",
                 "Phi-3-mini-4k-instruct-q4.gguf",
             ),
+            Which::SmolLM2_360MInstruct => (
+                "HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
+                "smollm2-360m-instruct-q8_0.gguf",
+            ),
+            Which::SmolLM2_1BInstruct => (
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
+                "smollm2-1.7b-instruct-q4_k_m.gguf",
+            ),
         };
         let revision = if self.which == Which::Phi3 {
             "5eef2ce24766d31909c0b269fe90c817a8f263fb"
@@ -455,6 +475,8 @@ fn main() -> anyhow::Result<()> {
         | Which::Leo7b
         | Which::Leo13b
         | Which::L8b
+        | Which::SmolLM2_1BInstruct
+        | Which::SmolLM2_360MInstruct
         | Which::Phi3 => 1,
         Which::Mixtral
         | Which::MixtralInstruct
@@ -573,6 +595,7 @@ fn main() -> anyhow::Result<()> {
     }
 
     let eos_token = match args.which {
+        Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
         Which::L8b => "<|end_of_text|>",
         _ => match args.which.is_open_chat() {
             true => "<|end_of_turn|>",
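The quantized path follows the same GGUF recipe as the other models here; the
repo names, filenames, and the <|endoftext|> EOS token come straight from the
hunks above. A sketch of the resulting load sequence, assuming candle's
gguf_file API and the quantized_llama model this example already imports (the
function name is illustrative):

    use candle::quantized::gguf_file;
    use candle_transformers::models::quantized_llama::ModelWeights;
    use hf_hub::api::sync::Api;

    fn load_smollm2_gguf(device: &candle::Device) -> anyhow::Result<ModelWeights> {
        let repo = Api::new()?.model("HuggingFaceTB/SmolLM2-360M-Instruct-GGUF".to_string());
        let path = repo.get("smollm2-360m-instruct-q8_0.gguf")?;
        let mut file = std::fs::File::open(&path)?;
        // Read the GGUF header and tensor metadata first; the reader is then
        // handed to `from_gguf` so tensor data can be loaded per tensor.
        let content = gguf_file::Content::read(&mut file)?;
        Ok(ModelWeights::from_gguf(content, &mut file, device)?)
    }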
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 6b326fbe92..20363aeab7 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -351,13 +351,16 @@ impl ModelWeights {
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
-        let tok_embeddings = tok_embeddings.dequantize(device)?;
+        let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = tok_embeddings_q.dequantize(device)?;
         let norm = RmsNorm::from_qtensor(
             ct.tensor(reader, "output_norm.weight", device)?,
             rms_norm_eps,
         )?;
-        let output = ct.tensor(reader, "output.weight", device)?;
+        let output = match ct.tensor(reader, "output.weight", device) {
+            Ok(tensor) => tensor,
+            Err(_) => tok_embeddings_q,
+        };
 
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
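The output.weight fallback is what makes the SmolLM2 GGUF files loadable: they
use tied embeddings, so no separate LM-head tensor is stored and the quantized
token-embedding tensor doubles as the output projection. A hypothetical helper
(not part of the patch) to check which case a given GGUF file falls into,
assuming gguf_file::Content's public tensor_infos map:

    use candle::quantized::gguf_file;

    // Returns true when the file carries a dedicated LM head, false when the
    // tied-embedding fallback above will kick in.
    fn has_separate_lm_head(path: &std::path::Path) -> anyhow::Result<bool> {
        let mut file = std::fs::File::open(path)?;
        let content = gguf_file::Content::read(&mut file)?;
        Ok(content.tensor_infos.contains_key("output.weight"))
    }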