diff --git a/candle-examples/examples/nvembed_v2/README.md b/candle-examples/examples/nvembed_v2/README.md new file mode 100644 index 000000000..66b10fab0 --- /dev/null +++ b/candle-examples/examples/nvembed_v2/README.md @@ -0,0 +1,43 @@ +# NV-Embed-v2 + +Candle implementation (inference only) of [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2), a text embedding model that ranks No. 1 (as of Nov 25 2024) on the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) benchmark with a score of 72.31 across 56 text embedding tasks. + +## Running an example: Retrieval +```bash +cargo run --example nvembed_v2 --release +> scores: [[87.4269, 0.4629], +> [ 0.9653, 86.0372]] +> Tensor[[2, 2], f32] +``` +In this example, we have two queries and two passages (the corresponding answers). The output tensor represents the similarity scores between each query-passage pair. The scores are computed by taking the dot product of the query and passage embeddings and scaling the result by 100. +```rust +let queries = [ + "are judo throws allowed in wrestling?", + "how to become a radiology technician in michigan?", +]; +let query_instruction = + "Instruct: Given a question, retrieve passages that answer the question\nQuery: " + .to_string(); + +let passages = [ + "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", + "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." +]; +let passage_instruction = "".to_string(); +``` + +If you already have the model and tokenizer files, you can use the `--tokenizer` and `--model-files` options to specify their full paths, instead of downloading them from the hub. + +## Running an example: Sentence embedding +```bash +cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence" +> Embedding: [[ 0.0066, -0.0048, 0.0066, ..., -0.0096, 0.0119, -0.0052]] +> Tensor[[1, 4096], f32] +``` +In this example, we pass a prompt to the model and it outputs the vector encoding of the prompt. + +## Hardware Requirements +29.25GB at fp32 + +## License +CC-BY-NC-4.0. This model should not be used for any commercial purpose. Refer the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms. diff --git a/candle-examples/examples/nvembed_v2/main.rs b/candle-examples/examples/nvembed_v2/main.rs new file mode 100644 index 000000000..8db9a100f --- /dev/null +++ b/candle-examples/examples/nvembed_v2/main.rs @@ -0,0 +1,214 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use candle::{DType, IndexOp, Shape, Tensor, D}; +use candle_nn::VarBuilder; +use candle_transformers::models::nvembed_v2::model::Model; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::{PaddingDirection, PaddingParams, Tokenizer, TruncationParams}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: Option, + + /// L2 normalization for embeddings. + #[arg(long, default_value = "true")] + normalize_embeddings: bool, + + #[arg(long)] + tokenizer: Option, + + #[arg(long)] + model: Option, + + /// Comma-separated list of model files (e.g., '/path/file1.safetensors,/path/file2.safetensors,/path/file3.safetensors') + #[arg(long)] + model_files: Option, +} + +impl Args { + fn build_model_and_tokenizer(&self) -> anyhow::Result<(Model, tokenizers::Tokenizer)> { + let model_name = match self.model.as_ref() { + Some(model) => model.to_string(), + None => "nvidia/NV-Embed-v2".to_string(), + }; + + let api = Api::new()?; + let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model)); + + let model_files = match &self.model_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let tokenizer_file = match &self.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + let device = candle_examples::device(self.cpu)?; + + let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; + + let _ = tokenizer + .with_padding(Some(PaddingParams { + direction: PaddingDirection::Right, + pad_id: 2, + pad_token: "".to_string(), + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: 32768, + ..Default::default() + })); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device) }?; + + let nvembed_model = Model::new(vb); + Ok((nvembed_model?, tokenizer)) + } +} + +fn encode( + model: &mut Model, + tokenizer: &Tokenizer, + examples: Vec, + instruction: &str, +) -> Result { + let device = &model.device; + let dtype = model.dtype; + + // Format input text + let eos_token = if let Some(padding) = tokenizer.get_padding() { + padding.pad_token.clone() + } else { + "".to_string() + }; + let bos = "".to_string(); + let input_texts = examples + .iter() + .map(|input_example| format!("{bos}{instruction}{input_example}{eos_token}")) + .collect::>(); + + // Tokenize + let encodings = tokenizer.encode_batch(input_texts, false).map_err(E::msg)?; + + let input_ids_list = encodings + .iter() + .map(|encoding| { + Tensor::from_slice( + encoding.get_ids(), + Shape::from(encoding.get_ids().len()), + device, + ) + }) + .collect::, _>>()?; + let input_ids = Tensor::stack(&input_ids_list, 0)?; + + // Mask out padding tokens for both embedding model and latent attention model + let attention_masks: Vec = encodings + .iter() + .map(|encoding| { + Tensor::from_slice( + encoding.get_attention_mask(), + Shape::from(encoding.get_attention_mask().len()), + device, + )? + .to_dtype(dtype) + }) + .collect::, _>>()?; + let attention_mask = Tensor::stack(&attention_masks, 0)?; + + // Mask out instruction tokens for latent attention model + let pool_mask = if !instruction.is_empty() { + let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?; + let instruction_lens = encoded_instruction.get_tokens().len(); + let zeros = Tensor::zeros( + attention_mask.i((.., ..instruction_lens))?.shape(), + dtype, + device, + )?; + let b = attention_mask.dims()[0]; + attention_mask.slice_assign(&[..b, ..instruction_lens], &zeros)? + } else { + attention_mask.clone() + }; + + let hiddens = model + .forward(&input_ids, &attention_mask, &pool_mask)? + .squeeze(1)?; + + // Normalize embedding + div_l2_norm(&hiddens) +} + +fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + Ok(v.broadcast_div(&l2_norm)?) +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let (mut model, tokenizer) = args.build_model_and_tokenizer()?; + + if let Some(prompt) = args.prompt { + let emb = encode(&mut model, &tokenizer, vec![prompt], "")?; + println!("Embedding: {emb}"); + } else { + let queries = [ + "are judo throws allowed in wrestling?", + "how to become a radiology technician in michigan?", + ]; + + let passages = [ + "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", + "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." + ]; + let passage_instruction = "".to_string(); + let query_instruction = + "Instruct: Given a question, retrieve passages that answer the question\nQuery: " + .to_string(); + + let passages: Vec = passages.iter().map(|s| s.to_string()).collect(); + let queries: Vec = queries.iter().map(|s| s.to_string()).collect(); + + let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?; + let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?; + + let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?; + + println!("scores: {scores}"); + } + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 571a88614..be1f15c41 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -62,6 +62,7 @@ pub mod mobilenetv4; pub mod mobileone; pub mod moondream; pub mod mpt; +pub mod nvembed_v2; pub mod olmo; pub mod openclip; pub mod paligemma; diff --git a/candle-transformers/src/models/nvembed_v2/embedding.rs b/candle-transformers/src/models/nvembed_v2/embedding.rs new file mode 100644 index 000000000..a52192afd --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/embedding.rs @@ -0,0 +1,294 @@ +/// Mistral LLM, https://github.com/mistralai/mistral-src +use crate::models::{ + mistral::Config, + with_tracing::{linear_no_bias, Linear, RmsNorm}, +}; +use crate::utils::repeat_kv; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let rope_theta = cfg.rope_theta as f32; + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let key_states = repeat_kv(key_states, self.num_kv_groups)?; + let value_states = repeat_kv(value_states, self.num_kv_groups)?; + + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&value_states)?; + + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + pub cfg: Config, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?; + Ok(Self { + embed_tokens, + layers, + norm, + cfg: cfg.clone(), + }) + } + + // Attn mask used to mask out padding tokens + pub fn forward( + &mut self, + attn_mask: &Tensor, + input_ids: &Tensor, + dtype: DType, + ) -> Result { + let mut xs = self.embed_tokens.forward(input_ids)?; + + // Expand to 4d mask for sdpa + let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?; + + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, Some(&attn_mask), 0)?; + } + + // Return hiddens instead of logits + xs.apply(&self.norm) + } +} + +fn prepare_4d_attention_mask( + mask: &Tensor, + dtype: DType, + tgt_len: Option, +) -> Result { + let bsz = mask.dims()[0]; + let src_len = mask.dims()[1]; + let tgt_len = tgt_len.unwrap_or(src_len); + + let expanded_mask = mask + .unsqueeze(1)? + .unsqueeze(2)? + .expand((bsz, 1, tgt_len, src_len))? + .to_dtype(dtype)?; + + let inverted_mask = (1.0 - expanded_mask)?; + + (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype) +} + +fn get_dtype_min_val(dtype: DType) -> f64 { + match dtype { + DType::F32 => f32::MIN as f64, + DType::F64 => f64::MIN, + _ => panic!("Unsupported data type"), + } +} diff --git a/candle-transformers/src/models/nvembed_v2/mod.rs b/candle-transformers/src/models/nvembed_v2/mod.rs new file mode 100644 index 000000000..8a8f70078 --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/mod.rs @@ -0,0 +1,18 @@ +//! NV-Embed-v2 +//! +//! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings. +//! +//! This implementation is based on the [paper](https://arxiv.org/pdf/2405.17428) and [weights](https://huggingface.co/nvidia/NV-Embed-v2) +//! +//! # Query-Passage Retrieval Example +//! ```bash +//! cargo run --example nvembed_v2 --release +//! ``` +//! +//! # Sentence Embedding Example +//! ```bash +//! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence" +//! ``` + +pub mod embedding; +pub mod model; diff --git a/candle-transformers/src/models/nvembed_v2/model.rs b/candle-transformers/src/models/nvembed_v2/model.rs new file mode 100644 index 000000000..73ef776e3 --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/model.rs @@ -0,0 +1,233 @@ +use super::embedding::Model as EmbeddingModel; +use crate::models::{ + mistral::Config, + with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear}, +}; +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder}; + +// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct GeGlu { + proj: Linear, + span: tracing::Span, +} + +impl GeGlu { + fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result { + let proj = linear(dim_in, dim_out * 2, vs)?; + let span = tracing::span!(tracing::Level::TRACE, "geglu"); + Ok(Self { proj, span }) + } +} + +impl Module for GeGlu { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? + } +} + +#[derive(Debug)] +struct FeedForward { + project_in: GeGlu, + linear: Linear, + span: tracing::Span, +} + +impl FeedForward { + fn new(vs: VarBuilder, dim: usize, dim_out: Option, mult: usize) -> Result { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = linear(inner_dim, dim_out, vs.pp("2"))?; + let span = tracing::span!(tracing::Level::TRACE, "ff"); + Ok(Self { + project_in, + linear, + span, + }) + } +} + +impl Module for FeedForward { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) + } +} + +// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct CrossAttention { + to_q: Linear, + to_kv: Linear, + to_out: Linear, + heads: usize, + scale: f64, + span: tracing::Span, + span_attn: tracing::Span, + span_softmax: tracing::Span, +} + +impl CrossAttention { + fn new( + vs: VarBuilder, + query_dim: usize, + context_dim: Option, + heads: usize, + dim_head: usize, + ) -> Result { + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?; + let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?; + let span = tracing::span!(tracing::Level::TRACE, "xa"); + let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax"); + Ok(Self { + to_q, + to_kv, + to_out, + heads, + scale, + span, + span_attn, + span_softmax, + }) + } + + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? + .transpose(1, 2)? + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } + + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? + .transpose(1, 2)? + .reshape((batch_size / self.heads, seq_len, dim * self.heads)) + } + + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result { + let _enter = self.span_attn.enter(); + + let in_dtype = query.dtype(); + let query = query.to_dtype(DType::F32)?; + let key = key.to_dtype(DType::F32)?; + let value = value.to_dtype(DType::F32)?; + let xs = query.matmul(&(key.t()? * self.scale)?)?; + let xs = { + let _enter = self.span_softmax.enter(); + softmax_last_dim(&xs)? + }; + let xs = xs.matmul(&value)?.to_dtype(in_dtype)?; + + self.reshape_batch_dim_to_heads(&xs) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs).contiguous()?; + let kv_chunks = self + .to_kv + .forward(&context)? + .chunk(2, context.shape().dims().len() - 1)?; + let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone()); + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; + + let xs = self.attention(&query, &key, &value)?; + self.to_out.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Model { + embedding_model: EmbeddingModel, + cross_attn: CrossAttention, + cross_attn_norm: LayerNorm, + cross_attn_context_norm: LayerNorm, + ff: FeedForward, + ff_norm: LayerNorm, + latents: Tensor, + pub device: Device, + pub dtype: DType, +} + +impl Model { + pub fn new(vb: VarBuilder) -> Result { + // Embedding model + let cfg = Config::config_7b_v0_1(false); + let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?; + + // Latent attention + let dim = 4096; + let vb = vb.pp("latent_attention_model"); + let latents = vb.get((512, dim), "latents")?; + + // Cross attend blocks + let vb = vb.pp("cross_attend_blocks"); + let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?; + let cross_attn_context_norm = layer_norm( + dim, + candle_nn::LayerNormConfig::default(), + vb.pp("0.norm_context"), + )?; + let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?; + + let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?; + let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?; + + Ok(Self { + embedding_model, + cross_attn, + cross_attn_norm, + cross_attn_context_norm, + ff, + ff_norm, + latents, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + attn_mask: &Tensor, + pool_mask: &Tensor, + ) -> Result { + // Embedding model + let hiddens = self + .embedding_model + .forward(attn_mask, input_ids, self.dtype)?; + + // Latent attention + let b = hiddens.dims()[0]; + let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?; + let original_hiddens = &hiddens; + + let hiddens = self.cross_attn_norm.forward(original_hiddens)?; + let x = self.cross_attn_context_norm.forward(&x)?; + let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?; + + let hiddens = self.ff_norm.forward(&cross_hiddens)?; + let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?; + + // Mean pooling + let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?; + let s = hiddens_masked.sum(1)?; + let d = pool_mask.sum_keepdim(1)?; + s.broadcast_div(&d) + } +}