diff --git a/Cargo.lock b/Cargo.lock index 30a3e49..3b81d4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "llm-ls" -version = "0.1.0" +version = "0.1.1" dependencies = [ "home", "reqwest", diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml index 87809a3..ee8f022 100644 --- a/crates/llm-ls/Cargo.toml +++ b/crates/llm-ls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "llm-ls" -version = "0.1.1" +version = "0.2.0" edition = "2021" [[bin]] diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 44b838b..3c5b2d4 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter; const NAME: &str = "llm-ls"; const VERSION: &str = env!("CARGO_PKG_VERSION"); +#[derive(Debug, Deserialize, Serialize)] +#[serde(untagged)] +enum TokenizerConfig { + Local { path: PathBuf }, + HuggingFace { repository: String }, + Download { url: String, to: PathBuf }, +} + #[derive(Clone, Debug, Deserialize, Serialize)] struct RequestParams { max_new_tokens: u32, @@ -120,9 +128,9 @@ struct Completion { generated_text: String, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] -enum IDE { +enum Ide { Neovim, VSCode, JetBrains, @@ -130,26 +138,21 @@ enum IDE { Jupyter, Sublime, VisualStudio, + #[default] Unknown, } -impl Default for IDE { - fn default() -> Self { - IDE::Unknown - } -} - -impl Display for IDE { +impl Display for Ide { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.serialize(f) } } -fn parse_ide<'de, D>(d: D) -> std::result::Result +fn parse_ide<'de, D>(d: D) -> std::result::Result where D: Deserializer<'de>, { - Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown)) + Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown)) } #[derive(Debug, Deserialize, Serialize)] @@ -159,12 +162,12 @@ struct CompletionParams { request_params: RequestParams, #[serde(default)] #[serde(deserialize_with = "parse_ide")] - ide: IDE, + ide: Ide, fim: FimParams, api_token: Option, model: String, tokens_to_clear: Vec, - tokenizer_path: Option, + tokenizer_config: Option, context_window: usize, tls_skip_verify_insecure: bool, } @@ -183,7 +186,7 @@ fn build_prompt( pos: Position, text: &Rope, fim: &FimParams, - tokenizer: Arc, + tokenizer: Option>, context_window: usize, ) -> Result { let t = Instant::now(); @@ -206,10 +209,14 @@ fn build_prompt( while before_line.is_some() || after_line.is_some() { if let Some(before_line) = before_line { let before_line = before_line.to_string(); - let tokens = tokenizer - .encode(before_line.clone(), false) - .map_err(internal_error)? - .len(); + let tokens = if let Some(tokenizer) = tokenizer.clone() { + tokenizer + .encode(before_line.clone(), false) + .map_err(internal_error)? + .len() + } else { + before_line.len() + }; if tokens > token_count { break; } @@ -218,10 +225,14 @@ fn build_prompt( } if let Some(after_line) = after_line { let after_line = after_line.to_string(); - let tokens = tokenizer - .encode(after_line.clone(), false) - .map_err(internal_error)? - .len(); + let tokens = if let Some(tokenizer) = tokenizer.clone() { + tokenizer + .encode(after_line.clone(), false) + .map_err(internal_error)? + .len() + } else { + after_line.len() + }; if tokens > token_count { break; } @@ -253,10 +264,14 @@ fn build_prompt( first = false; } let line = line.to_string(); - let tokens = tokenizer - .encode(line.clone(), false) - .map_err(internal_error)? - .len(); + let tokens = if let Some(tokenizer) = tokenizer.clone() { + tokenizer + .encode(line.clone(), false) + .map_err(internal_error)? + .len() + } else { + line.len() + }; if tokens > token_count { break; } @@ -272,7 +287,7 @@ fn build_prompt( async fn request_completion( http_client: &reqwest::Client, - ide: IDE, + ide: Ide, model: &str, request_params: RequestParams, api_token: Option<&String>, @@ -311,9 +326,10 @@ fn parse_generations(generations: Vec, tokens_to_clear: &[String]) - async fn download_tokenizer_file( http_client: &reqwest::Client, - model: &str, + url: &str, api_token: Option<&String>, to: impl AsRef, + ide: Ide, ) -> Result<()> { if to.as_ref().exists() { return Ok(()); @@ -325,13 +341,9 @@ async fn download_tokenizer_file( ) .await .map_err(internal_error)?; - let mut req = http_client.get(format!( - "https://huggingface.co/{model}/resolve/main/tokenizer.json" - )); - if let Some(api_token) = api_token { - req = req.header(AUTHORIZATION, format!("Bearer {api_token}")) - } - let res = req + let res = http_client + .get(url) + .headers(build_headers(api_token, ide)?) .send() .await .map_err(internal_error)? @@ -352,27 +364,37 @@ async fn download_tokenizer_file( async fn get_tokenizer( model: &str, tokenizer_map: &mut HashMap>, - tokenizer_path: Option<&String>, + tokenizer_config: Option, http_client: &reqwest::Client, cache_dir: impl AsRef, api_token: Option<&String>, -) -> Result> { + ide: Ide, +) -> Result>> { if let Some(tokenizer) = tokenizer_map.get(model) { - return Ok(tokenizer.clone()); + return Ok(Some(tokenizer.clone())); } - let tokenizer = if model.starts_with("http://") || model.starts_with("https://") { - match tokenizer_path { - Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?), - None => return Err(internal_error("`tokenizer_path` is null")), - } + if let Some(config) = tokenizer_config { + let tokenizer = match config { + TokenizerConfig::Local { path } => { + Arc::new(Tokenizer::from_file(path).map_err(internal_error)?) + } + TokenizerConfig::HuggingFace { repository } => { + let path = cache_dir.as_ref().join(model).join("tokenizer.json"); + let url = + format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json"); + download_tokenizer_file(http_client, &url, api_token, &path, ide).await?; + Arc::new(Tokenizer::from_file(path).map_err(internal_error)?) + } + TokenizerConfig::Download { url, to } => { + download_tokenizer_file(http_client, &url, api_token, &to, ide).await?; + Arc::new(Tokenizer::from_file(to).map_err(internal_error)?) + } + }; + tokenizer_map.insert(model.to_owned(), tokenizer.clone()); + Ok(Some(tokenizer)) } else { - let path = cache_dir.as_ref().join(model).join("tokenizer.json"); - download_tokenizer_file(http_client, model, api_token, &path).await?; - Arc::new(Tokenizer::from_file(path).map_err(internal_error)?) - }; - - tokenizer_map.insert(model.to_owned(), tokenizer.clone()); - Ok(tokenizer) + Ok(None) + } } fn build_url(model: &str) -> String { @@ -394,10 +416,11 @@ impl Backend { let tokenizer = get_tokenizer( ¶ms.model, &mut *self.tokenizer_map.write().await, - params.tokenizer_path.as_ref(), + params.tokenizer_config, &self.http_client, &self.cache_dir, params.api_token.as_ref(), + params.ide, ) .await?; let prompt = build_prompt( @@ -508,7 +531,7 @@ impl LanguageServer for Backend { } } -fn build_headers(api_token: Option<&String>, ide: IDE) -> Result { +fn build_headers(api_token: Option<&String>, ide: Ide) -> Result { let mut headers = HeaderMap::new(); let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); headers.insert( diff --git a/crates/mock_server/src/main.rs b/crates/mock_server/src/main.rs index d5b3c1d..4569679 100644 --- a/crates/mock_server/src/main.rs +++ b/crates/mock_server/src/main.rs @@ -1,7 +1,10 @@ use axum::{extract::State, http::HeaderMap, routing::post, Json, Router}; use serde::{Deserialize, Serialize}; use std::{net::SocketAddr, sync::Arc}; -use tokio::sync::Mutex; +use tokio::{ + sync::Mutex, + time::{sleep, Duration}, +}; #[derive(Clone)] struct AppState { @@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State) -> Json) -> Json { + let mut lock = state.counter.lock().await; + *lock += 1; + sleep(Duration::from_millis(200)).await; + println!("waited for req {}", lock); + Json(GeneratedText { + generated_text: "dummy".to_owned(), + }) +} + #[tokio::main] async fn main() { let app_state = AppState { @@ -50,11 +63,12 @@ async fn main() { .route("/", post(default)) .route("/tgi", post(tgi)) .route("/headers", post(log_headers)) + .route("/wait", post(wait)) .with_state(app_state); let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242) .parse() .expect("string to parse to socket addr"); - println!("starting server {}:{}", addr.ip().to_string(), addr.port(),); + println!("starting server {}:{}", addr.ip(), addr.port(),); axum::Server::bind(&addr) .serve(app.into_make_service())