From 60c8e80f591953f1338d32e3dadebc45a44d2802 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 17 Aug 2024 12:30:28 -0700 Subject: [PATCH 1/2] Custom actions, server shutdown fixes and a bunch of small things --- Cargo.lock | 5 +- crates/lsp-ai/Cargo.toml | 1 + crates/lsp-ai/src/config.rs | 42 ++- crates/lsp-ai/src/main.rs | 27 +- .../lsp-ai/src/memory_backends/file_store.rs | 25 +- crates/lsp-ai/src/memory_backends/mod.rs | 42 +-- .../src/memory_backends/postgresml/mod.rs | 17 +- .../src/memory_backends/vector_store.rs | 18 +- crates/lsp-ai/src/memory_worker.rs | 13 +- .../src/transformer_backends/anthropic.rs | 51 ++-- .../lsp-ai/src/transformer_backends/gemini.rs | 4 +- .../src/transformer_backends/mistral_fim.rs | 27 +- .../lsp-ai/src/transformer_backends/ollama.rs | 103 ++++--- .../src/transformer_backends/open_ai/mod.rs | 111 ++++--- crates/lsp-ai/src/transformer_worker.rs | 281 ++++++++++++++---- crates/lsp-ai/src/utils.rs | 24 +- crates/lsp-ai/tests/integration_tests.rs | 252 +++++++++++++++- 17 files changed, 793 insertions(+), 250 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6dfe424..da300eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1590,6 +1590,7 @@ dependencies = [ "pgml", "rand", "rayon", + "regex", "reqwest", "ropey", "serde", @@ -2265,9 +2266,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", diff --git a/crates/lsp-ai/Cargo.toml b/crates/lsp-ai/Cargo.toml index e2f1222..088e5b3 100644 --- a/crates/lsp-ai/Cargo.toml +++ b/crates/lsp-ai/Cargo.toml @@ -43,6 +43,7 @@ fxhash = "0.2.1" ordered-float = "4.2.1" futures = "0.3" clap = { version = "4.5.14", features = ["derive"] } +regex = "1.10.6" [build-dependencies] cc="1" diff --git a/crates/lsp-ai/src/config.rs b/crates/lsp-ai/src/config.rs index 763ede1..89abfa4 100644 --- a/crates/lsp-ai/src/config.rs +++ b/crates/lsp-ai/src/config.rs @@ -9,15 +9,23 @@ const fn max_requests_per_second_default() -> f32 { 1. } +const fn true_default() -> bool { + true +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct PostProcess { + pub extractor: Option, + #[serde(default = "true_default")] pub remove_duplicate_start: bool, + #[serde(default = "true_default")] pub remove_duplicate_end: bool, } impl Default for PostProcess { fn default() -> Self { Self { + extractor: None, remove_duplicate_start: true, remove_duplicate_end: true, } @@ -353,13 +361,31 @@ pub struct Chat { pub(crate) parameters: Kwargs, } +#[derive(Clone, Debug, Deserialize)] +pub struct Action { + // The name to display in the editor + pub(crate) action_display_name: String, + // The model key to use + pub(crate) model: String, + // Args are deserialized by the backend using them + #[serde(default)] + pub(crate) parameters: Kwargs, + // Parameters for post processing + #[serde(default)] + pub(crate) post_process: PostProcess, +} + #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct ValidConfig { pub(crate) memory: ValidMemoryBackend, pub(crate) models: HashMap, pub(crate) completion: Option, - pub(crate) chat: Option>, + #[serde(default)] + pub(crate) actions: Vec, + #[serde(default)] + #[serde(alias = "chat")] // Legacy from when it was called chat, remove soon + pub(crate) chats: Vec, } #[derive(Clone, Debug, Deserialize, Default)] @@ -396,8 +422,12 @@ impl Config { // Helpers for the backends /////////// /////////////////////////////////////// - pub fn get_chat(&self) -> Option<&Vec> { - self.config.chat.as_ref() + pub fn get_chats(&self) -> &Vec { + &self.config.chats + } + + pub fn get_actions(&self) -> &Vec { + &self.config.actions } pub fn is_completions_enabled(&self) -> bool { @@ -446,7 +476,8 @@ impl Config { memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }), models: HashMap::new(), completion: None, - chat: None, + actions: vec![], + chats: vec![], }, client_params: ValidClientParams { root_uri: None }, } @@ -458,7 +489,8 @@ impl Config { memory: ValidMemoryBackend::VectorStore(vector_store), models: HashMap::new(), completion: None, - chat: None, + actions: vec![], + chats: vec![], }, client_params: ValidClientParams { root_uri: None }, } diff --git a/crates/lsp-ai/src/main.rs b/crates/lsp-ai/src/main.rs index e7419eb..9141b42 100644 --- a/crates/lsp-ai/src/main.rs +++ b/crates/lsp-ai/src/main.rs @@ -3,7 +3,7 @@ use clap::Parser; use directories::BaseDirs; use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId}; use lsp_types::{ - request::{CodeActionRequest, CodeActionResolveRequest, Completion}, + request::{CodeActionRequest, CodeActionResolveRequest, Completion, Shutdown}, CodeActionOptions, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, }; @@ -127,7 +127,10 @@ fn main() -> Result<()> { })?; let initialization_args = connection.initialize(server_capabilities)?; - main_loop(connection, initialization_args)?; + if let Err(e) = main_loop(connection, initialization_args) { + error!("{e:?}"); + } + io_threads.join()?; Ok(()) } @@ -147,7 +150,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { // Setup the transformer worker let memory_backend: Box = config.clone().try_into()?; - thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); + let memory_worker_thread = thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); // Setup our transformer worker let transformer_backends: HashMap> = config @@ -160,7 +163,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let thread_connection = connection.clone(); let thread_memory_tx = memory_tx.clone(); let thread_config = config.clone(); - thread::spawn(move || { + let transformer_worker_thread = thread::spawn(move || { transformer_worker::run( transformer_backends, thread_memory_tx, @@ -173,10 +176,18 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { for msg in &connection.receiver { match msg { Message::Request(req) => { - if connection.handle_shutdown(&req)? { + if request_is::(&req) { + memory_tx.send(memory_worker::WorkerRequest::Shutdown)?; + if let Err(e) = memory_worker_thread.join() { + std::panic::resume_unwind(e) + } + transformer_tx.send(WorkerRequest::Shutdown)?; + if let Err(e) = transformer_worker_thread.join() { + std::panic::resume_unwind(e) + } + connection.handle_shutdown(&req)?; return Ok(()); - } - if request_is::(&req) { + } else if request_is::(&req) { match cast::(req) { Ok((id, params)) => { let completion_request = CompletionRequest::new(id, params); @@ -224,7 +235,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { Err(err) => error!("{err:?}"), } } else { - error!("Unsupported command - see the wiki for a list of supported commands") + error!("Unsupported command - see the wiki for a list of supported commands: {req:?}") } } Message::Notification(not) => { diff --git a/crates/lsp-ai/src/memory_backends/file_store.rs b/crates/lsp-ai/src/memory_backends/file_store.rs index 65aaad4..28a9243 100644 --- a/crates/lsp-ai/src/memory_backends/file_store.rs +++ b/crates/lsp-ai/src/memory_backends/file_store.rs @@ -240,20 +240,22 @@ impl FileStore { let rope_slice = rope .get_slice(start..end + "".chars().count()) .context("Error getting rope slice")?; - Prompt::ContextAndCode(ContextAndCodePrompt::new( - "".to_string(), - rope_slice.to_string(), - )) + Prompt::ContextAndCode(ContextAndCodePrompt { + context: "".to_string(), + code: rope_slice.to_string(), + selected_text: None, + }) } else { let start = cursor_index .saturating_sub(tokens_to_estimated_characters(params.max_context)); let rope_slice = rope .get_slice(start..cursor_index) .context("Error getting rope slice")?; - Prompt::ContextAndCode(ContextAndCodePrompt::new( - "".to_string(), - rope_slice.to_string(), - )) + Prompt::ContextAndCode(ContextAndCodePrompt { + context: "".to_string(), + code: rope_slice.to_string(), + selected_text: None, + }) } } PromptType::FIM => { @@ -268,7 +270,10 @@ impl FileStore { let suffix = rope .get_slice(cursor_index..end) .context("Error getting rope slice")?; - Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string())) + Prompt::FIM(FIMPrompt { + prompt: prefix.to_string(), + suffix: suffix.to_string(), + }) } }) } @@ -837,8 +842,6 @@ mod tests { #[test] fn test_file_store_tree_sitter() -> anyhow::Result<()> { - crate::init_logger(); - let config = Config::default_with_file_store_without_models(); let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = config.config.memory.clone() diff --git a/crates/lsp-ai/src/memory_backends/mod.rs b/crates/lsp-ai/src/memory_backends/mod.rs index 085753e..126649b 100644 --- a/crates/lsp-ai/src/memory_backends/mod.rs +++ b/crates/lsp-ai/src/memory_backends/mod.rs @@ -36,12 +36,7 @@ impl From<&Value> for MemoryRunParams { pub struct ContextAndCodePrompt { pub context: String, pub code: String, -} - -impl ContextAndCodePrompt { - pub fn new(context: String, code: String) -> Self { - Self { context, code } - } + pub selected_text: Option, } #[derive(Debug)] @@ -50,15 +45,6 @@ pub struct FIMPrompt { pub suffix: String, } -impl FIMPrompt { - pub fn new(prefix: String, suffix: String) -> Self { - Self { - prompt: prefix, - suffix, - } - } -} - #[derive(Debug)] pub enum Prompt { FIM(FIMPrompt), @@ -159,23 +145,25 @@ impl TryFrom for Box { #[cfg(test)] impl Prompt { pub fn default_with_cursor() -> Self { - Self::ContextAndCode(ContextAndCodePrompt::new( - r#"def test_context():\n pass"#.to_string(), - r#"def test_code():\n "#.to_string(), - )) + Self::ContextAndCode(ContextAndCodePrompt { + context: r#"def test_context():\n pass"#.to_string(), + code: r#"def test_code():\n "#.to_string(), + selected_text: None, + }) } pub fn default_fim() -> Self { - Self::FIM(FIMPrompt::new( - r#"def test_context():\n pass"#.to_string(), - r#"def test_code():\n "#.to_string(), - )) + Self::FIM(FIMPrompt { + prompt: r#"def test_context():\n pass"#.to_string(), + suffix: r#"def test_code():\n "#.to_string(), + }) } pub fn default_without_cursor() -> Self { - Self::ContextAndCode(ContextAndCodePrompt::new( - r#"def test_context():\n pass"#.to_string(), - r#"def test_code():\n "#.to_string(), - )) + Self::ContextAndCode(ContextAndCodePrompt { + context: r#"def test_context():\n pass"#.to_string(), + code: r#"def test_code():\n "#.to_string(), + selected_text: None, + }) } } diff --git a/crates/lsp-ai/src/memory_backends/postgresml/mod.rs b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs index a006bf1..a003d7c 100644 --- a/crates/lsp-ai/src/memory_backends/postgresml/mod.rs +++ b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs @@ -589,19 +589,20 @@ impl MemoryBackend for PostgresML { // Reconstruct the Prompts Ok(match code { Prompt::ContextAndCode(context_and_code) => { - Prompt::ContextAndCode(ContextAndCodePrompt::new( - context.to_owned(), - format_file_chunk( + Prompt::ContextAndCode(ContextAndCodePrompt { + context: context.to_owned(), + code: format_file_chunk( position.text_document.uri.as_ref(), &context_and_code.code, self.config.client_params.root_uri.as_deref(), ), - )) + selected_text: None, + }) } - Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new( - format!("{context}\n\n{}", fim.prompt), - fim.suffix, - )), + Prompt::FIM(fim) => Prompt::FIM(FIMPrompt { + prompt: format!("{context}\n\n{}", fim.prompt), + suffix: fim.suffix, + }), }) } diff --git a/crates/lsp-ai/src/memory_backends/vector_store.rs b/crates/lsp-ai/src/memory_backends/vector_store.rs index 024463c..9d6ed2e 100644 --- a/crates/lsp-ai/src/memory_backends/vector_store.rs +++ b/crates/lsp-ai/src/memory_backends/vector_store.rs @@ -726,19 +726,20 @@ impl MemoryBackend for VectorStore { // Reconstruct the prompts Ok(match code { Prompt::ContextAndCode(context_and_code) => { - Prompt::ContextAndCode(ContextAndCodePrompt::new( - context.to_owned(), - format_file_chunk( + Prompt::ContextAndCode(ContextAndCodePrompt { + context: context.to_owned(), + code: format_file_chunk( position.text_document.uri.as_ref(), &context_and_code.code, self.config.client_params.root_uri.as_deref(), ), - )) + selected_text: None, + }) } - Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new( - format!("{context}\n\n{}", fim.prompt), - fim.suffix, - )), + Prompt::FIM(fim) => Prompt::FIM(FIMPrompt { + prompt: format!("{context}\n\n{}", fim.prompt), + suffix: fim.suffix, + }), }) } } @@ -935,7 +936,6 @@ assert multiply_two_numbers(2, 3) == 6 #[tokio::test] async fn can_build_prompt() -> anyhow::Result<()> { - crate::init_logger(); let text_document1 = generate_filler_text_document(None, None); let params = lsp_types::DidOpenTextDocumentParams { text_document: text_document1.clone(), diff --git a/crates/lsp-ai/src/memory_worker.rs b/crates/lsp-ai/src/memory_worker.rs index f22efa8..d91acd4 100644 --- a/crates/lsp-ai/src/memory_worker.rs +++ b/crates/lsp-ai/src/memory_worker.rs @@ -94,6 +94,7 @@ impl FileRequest { } pub(crate) enum WorkerRequest { + Shutdown, FilterText(FilterRequest), File(FileRequest), Prompt(PromptRequest), @@ -160,6 +161,7 @@ fn do_task( memory_backend.changed_text_document(params)?; } WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params)?, + WorkerRequest::Shutdown => unreachable!(), } anyhow::Ok(()) } @@ -171,8 +173,15 @@ fn do_run( let memory_backend = Arc::new(memory_backend); loop { let request = rx.recv()?; - if let Err(e) = do_task(request, memory_backend.clone()) { - error!("error in memory worker task: {e}") + match &request { + WorkerRequest::Shutdown => { + return Ok(()); + } + _ => { + if let Err(e) = do_task(request, memory_backend.clone()) { + error!("error in memory worker task: {e}") + } + } } } } diff --git a/crates/lsp-ai/src/transformer_backends/anthropic.rs b/crates/lsp-ai/src/transformer_backends/anthropic.rs index e7ea0dd..aab123e 100644 --- a/crates/lsp-ai/src/transformer_backends/anthropic.rs +++ b/crates/lsp-ai/src/transformer_backends/anthropic.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use anyhow::Context; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::{info, instrument}; @@ -32,6 +32,7 @@ const fn temperature_default() -> f32 { #[derive(Debug, Deserialize)] pub(crate) struct AnthropicRunParams { system: String, + #[serde(default)] messages: Vec, #[serde(default = "max_tokens_default")] pub(crate) max_tokens: usize, @@ -45,18 +46,27 @@ pub(crate) struct Anthropic { config: config::Anthropic, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] +struct AnthropicResponse { + content: Vec, +} + +#[derive(Deserialize, Serialize)] struct AnthropicChatMessage { text: String, } -#[derive(Deserialize)] -struct AnthropicChatResponse { - content: Option>, - error: Option, - #[serde(default)] - #[serde(flatten)] - pub(crate) other: HashMap, +#[derive(Deserialize, Serialize)] +pub struct ChatError { + error: Value, +} + +#[derive(Deserialize, Serialize)] +#[serde(untagged)] +enum ChatResponse { + Success(AnthropicResponse), + Error(ChatError), + Other(HashMap), } impl Anthropic { @@ -92,7 +102,7 @@ impl Anthropic { "Calling Anthropic compatible API with parameters:\n{}", serde_json::to_string_pretty(¶ms).unwrap() ); - let res: AnthropicChatResponse = client + let res: ChatResponse = client .post( self.config .chat_endpoint @@ -108,15 +118,18 @@ impl Anthropic { .await? .json() .await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(mut content) = res.content { - Ok(std::mem::take(&mut content[0].text)) - } else { - anyhow::bail!( - "Uknown error while making request to Anthropic: {:?}", - res.other - ) + info!( + "Response from Anthropic compatible API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + ChatResponse::Success(mut resp) => Ok(std::mem::take(&mut resp.content[0].text)), + ChatResponse::Error(error) => { + anyhow::bail!("making Anthropic request: {:?}", error.error.to_string()) + } + ChatResponse::Other(other) => { + anyhow::bail!("unknown error while making Anthropic request: {:?}", other) + } } } diff --git a/crates/lsp-ai/src/transformer_backends/gemini.rs b/crates/lsp-ai/src/transformer_backends/gemini.rs index db1c550..ae23fa2 100644 --- a/crates/lsp-ai/src/transformer_backends/gemini.rs +++ b/crates/lsp-ai/src/transformer_backends/gemini.rs @@ -10,7 +10,7 @@ use crate::{ transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, - utils::format_context_code_in_str, + utils::format_prompt_in_str, }; fn format_gemini_contents( @@ -25,7 +25,7 @@ fn format_gemini_contents( m.parts .iter() .map(|p| Part { - text: format_context_code_in_str(&p.text, &prompt.context, &prompt.code), + text: format_prompt_in_str(&p.text, &prompt), }) .collect(), ) diff --git a/crates/lsp-ai/src/transformer_backends/mistral_fim.rs b/crates/lsp-ai/src/transformer_backends/mistral_fim.rs index fe22644..0aec5f8 100644 --- a/crates/lsp-ai/src/transformer_backends/mistral_fim.rs +++ b/crates/lsp-ai/src/transformer_backends/mistral_fim.rs @@ -97,15 +97,24 @@ impl MistralFIM { .await? .json() .await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(choices) = res.choices { - Ok(choices[0].message.content.clone()) - } else { - anyhow::bail!( - "Unknown error while making request to MistralFIM: {:?}", - res.other - ); + + info!( + "Response from Mistral compatible FIM API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + OpenAIChatResponse::Success(mut resp) => { + Ok(std::mem::take(&mut resp.choices[0].message.content)) + } + OpenAIChatResponse::Error(error) => { + anyhow::bail!("making Mistral FIM request: {:?}", error.error.to_string()) + } + OpenAIChatResponse::Other(other) => { + anyhow::bail!( + "unknown error while making Mistral FIM request: {:?}", + other + ) + } } } } diff --git a/crates/lsp-ai/src/transformer_backends/ollama.rs b/crates/lsp-ai/src/transformer_backends/ollama.rs index d003e2b..678b486 100644 --- a/crates/lsp-ai/src/transformer_backends/ollama.rs +++ b/crates/lsp-ai/src/transformer_backends/ollama.rs @@ -9,7 +9,7 @@ use crate::{ transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, - utils::{format_chat_messages, format_context_code}, + utils::{format_chat_messages, format_prompt}, }; use super::TransformerBackend; @@ -30,13 +30,17 @@ pub(crate) struct Ollama { configuration: config::Ollama, } -#[derive(Deserialize)] -struct OllamaCompletionsResponse { - response: Option, - error: Option, - #[serde(default)] - #[serde(flatten)] - other: HashMap, +#[derive(Deserialize, Serialize)] +struct OllamaValidCompletionsResponse { + response: String, +} + +#[derive(Deserialize, Serialize)] +#[serde(untagged)] +enum OllamaCompletionsResponse { + Success(OllamaValidCompletionsResponse), + Error(OllamaError), + Other(HashMap), } #[derive(Debug, Deserialize, Serialize)] @@ -45,13 +49,22 @@ struct OllamaChatMessage { content: String, } -#[derive(Deserialize)] -struct OllamaChatResponse { - message: Option, - error: Option, - #[serde(default)] - #[serde(flatten)] - other: HashMap, +#[derive(Deserialize, Serialize)] +struct OllamaError { + error: Value, +} + +#[derive(Deserialize, Serialize)] +struct OllamaValidChatResponse { + message: OllamaChatMessage, +} + +#[derive(Deserialize, Serialize)] +#[serde(untagged)] +enum OllamaChatResponse { + Success(OllamaValidChatResponse), + Error(OllamaError), + Other(HashMap), } impl Ollama { @@ -75,7 +88,7 @@ impl Ollama { "stream": false }); info!( - "Calling Ollama compatible completion API with parameters:\n{}", + "Calling Ollama compatible completions API with parameters:\n{}", serde_json::to_string_pretty(¶ms).unwrap() ); let res: OllamaCompletionsResponse = client @@ -92,15 +105,24 @@ impl Ollama { .await? .json() .await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(response) = res.response { - Ok(response) - } else { - anyhow::bail!( - "Uknown error while making request to Ollama: {:?}", - res.other - ) + info!( + "Response from Ollama compatible completions API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + OllamaCompletionsResponse::Success(mut resp) => Ok(std::mem::take(&mut resp.response)), + OllamaCompletionsResponse::Error(error) => { + anyhow::bail!( + "making Ollama completions request: {:?}", + error.error.to_string() + ) + } + OllamaCompletionsResponse::Other(other) => { + anyhow::bail!( + "unknown error while making Ollama completions request: {:?}", + other + ) + } } } @@ -137,15 +159,21 @@ impl Ollama { .await? .json() .await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(message) = res.message { - Ok(message.content) - } else { - anyhow::bail!( - "Unknown error while making request to Ollama: {:?}", - res.other - ) + info!( + "Response from Ollama compatible chat API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + OllamaChatResponse::Success(mut resp) => Ok(std::mem::take(&mut resp.message.content)), + OllamaChatResponse::Error(error) => { + anyhow::bail!("making Ollama chat request: {:?}", error.error.to_string()) + } + OllamaChatResponse::Other(other) => { + anyhow::bail!( + "unknown error while making Ollama chat request: {:?}", + other + ) + } } } @@ -161,11 +189,8 @@ impl Ollama { self.get_chat(messages, params).await } None => { - self.get_completion( - &format_context_code(&code_and_context.context, &code_and_context.code), - params, - ) - .await + self.get_completion(&format_prompt(&code_and_context), params) + .await } }, Prompt::FIM(fim) => match ¶ms.fim { diff --git a/crates/lsp-ai/src/transformer_backends/open_ai/mod.rs b/crates/lsp-ai/src/transformer_backends/open_ai/mod.rs index a23669c..0f45254 100644 --- a/crates/lsp-ai/src/transformer_backends/open_ai/mod.rs +++ b/crates/lsp-ai/src/transformer_backends/open_ai/mod.rs @@ -11,7 +11,7 @@ use crate::{ transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, - utils::{format_chat_messages, format_context_code}, + utils::{format_chat_messages, format_prompt}, }; use super::TransformerBackend; @@ -57,18 +57,27 @@ pub(crate) struct OpenAI { configuration: config::OpenAI, } -#[derive(Deserialize)] -struct OpenAICompletionsChoice { +#[derive(Deserialize, Serialize)] +pub(crate) struct OpenAICompletionsChoice { text: String, } -#[derive(Deserialize)] -struct OpenAICompletionsResponse { - choices: Option>, - error: Option, - #[serde(default)] - #[serde(flatten)] - pub(crate) other: HashMap, +#[derive(Deserialize, Serialize)] +pub(crate) struct OpenAIError { + pub(crate) error: Value, +} + +#[derive(Deserialize, Serialize)] +pub(crate) struct OpenAIValidCompletionsResponse { + pub(crate) choices: Vec, +} + +#[derive(Deserialize, Serialize)] +#[serde(untagged)] +pub(crate) enum OpenAICompletionsResponse { + Success(OpenAIValidCompletionsResponse), + Error(OpenAIError), + Other(HashMap), } #[derive(Debug, Deserialize, Serialize)] @@ -77,18 +86,22 @@ pub(crate) struct OpenAIChatMessage { pub(crate) content: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] pub(crate) struct OpenAIChatChoices { pub(crate) message: OpenAIChatMessage, } -#[derive(Deserialize)] -pub(crate) struct OpenAIChatResponse { - pub(crate) choices: Option>, - pub(crate) error: Option, - #[serde(default)] - #[serde(flatten)] - pub(crate) other: HashMap, +#[derive(Deserialize, Serialize)] +pub(crate) struct OpenAIValidChatResponse { + pub(crate) choices: Vec, +} + +#[derive(Deserialize, Serialize)] +#[serde(untagged)] +pub(crate) enum OpenAIChatResponse { + Success(OpenAIValidChatResponse), + Error(OpenAIError), + Other(HashMap), } impl OpenAI { @@ -144,15 +157,26 @@ impl OpenAI { .json(¶ms) .send().await? .json().await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(mut choices) = res.choices { - Ok(std::mem::take(&mut choices[0].text)) - } else { - anyhow::bail!( - "Uknown error while making request to OpenAI: {:?}", - res.other - ) + info!( + "Response from OpenAI compatible completions API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + OpenAICompletionsResponse::Success(mut resp) => { + Ok(std::mem::take(&mut resp.choices[0].text)) + } + OpenAICompletionsResponse::Error(error) => { + anyhow::bail!( + "making OpenAI completions request: {:?}", + error.error.to_string() + ) + } + OpenAICompletionsResponse::Other(other) => { + anyhow::bail!( + "unknown error while making OpenAI completions request: {:?}", + other + ) + } } } @@ -192,15 +216,23 @@ impl OpenAI { .await? .json() .await?; - if let Some(error) = res.error { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(choices) = res.choices { - Ok(choices[0].message.content.clone()) - } else { - anyhow::bail!( - "Unknown error while making request to OpenAI: {:?}", - res.other - ) + info!( + "Response from OpenAI compatible chat API:\n{}", + serde_json::to_string_pretty(&res).unwrap() + ); + match res { + OpenAIChatResponse::Success(mut resp) => { + Ok(std::mem::take(&mut resp.choices[0].message.content)) + } + OpenAIChatResponse::Error(error) => { + anyhow::bail!("making OpenAI chat request: {:?}", error.error.to_string()) + } + OpenAIChatResponse::Other(other) => { + anyhow::bail!( + "unknown error while making OpenAI chat request: {:?}", + other + ) + } } } @@ -216,11 +248,8 @@ impl OpenAI { self.get_chat(messages, params).await } None => { - self.get_completion( - &format_context_code(&code_and_context.context, &code_and_context.code), - params, - ) - .await + self.get_completion(&format_prompt(&code_and_context), params) + .await } }, Prompt::FIM(fim) => match ¶ms.fim { diff --git a/crates/lsp-ai/src/transformer_worker.rs b/crates/lsp-ai/src/transformer_worker.rs index d52e3e3..94f29bd 100644 --- a/crates/lsp-ai/src/transformer_worker.rs +++ b/crates/lsp-ai/src/transformer_worker.rs @@ -5,13 +5,17 @@ use lsp_types::{ CompletionParams, CompletionResponse, Position, Range, TextDocumentIdentifier, TextDocumentPositionParams, TextEdit, WorkspaceEdit, }; +use once_cell::sync::Lazy; +use parking_lot::Mutex; +use regex::Regex; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::sync::mpsc::RecvTimeoutError; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::{ + collections::HashMap, + sync::{mpsc::RecvTimeoutError, Arc}, + time::{Duration, SystemTime}, +}; use tokio::sync::oneshot; -use tracing::{error, instrument}; +use tracing::{error, info, instrument}; use crate::config::{self, Config}; use crate::custom_requests::generation::{GenerateResult, GenerationParams}; @@ -21,6 +25,8 @@ use crate::memory_worker::{self, FileRequest, FilterRequest, PromptRequest}; use crate::transformer_backends::TransformerBackend; use crate::utils::{ToResponseError, TOKIO_RUNTIME}; +static RE: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); + #[derive(Clone, Debug)] pub(crate) struct CompletionRequest { id: RequestId, @@ -85,6 +91,7 @@ impl CodeActionResolveRequest { #[derive(Clone, Debug)] pub(crate) enum WorkerRequest { + Shutdown, Completion(CompletionRequest), Generation(GenerationRequest), GenerationStream(GenerationStreamRequest), @@ -95,6 +102,7 @@ pub(crate) enum WorkerRequest { impl WorkerRequest { fn get_id(&self) -> RequestId { match self { + WorkerRequest::Shutdown => unreachable!(), WorkerRequest::Completion(r) => r.id.clone(), WorkerRequest::Generation(r) => r.id.clone(), WorkerRequest::GenerationStream(r) => r.id.clone(), @@ -166,6 +174,27 @@ fn post_process_response( ) -> String { match prompt { Prompt::ContextAndCode(context_and_code) => { + // First we need to extract + let response = if let Some(extractor) = &config.extractor { + let mut re_map = RE.lock(); + let re = match re_map.get(extractor) { + Some(re) => re, + None => { + let re = Regex::new(extractor).unwrap(); + re_map.insert(extractor.to_owned(), re); + re_map.get(extractor).unwrap() + } + }; + let response = re + .captures(&response) + .and_then(|cap| cap.get(1)) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + info!("response text after extracting:\n{}", response); + response + } else { + response + }; if context_and_code.code.contains("") { let mut split = context_and_code.code.split(""); let response = if config.remove_duplicate_start { @@ -254,6 +283,9 @@ fn do_run( match request { Ok(request) => match &request { + WorkerRequest::Shutdown => { + return Ok(()); + } WorkerRequest::Completion(completion_request) => { if max_requests_per_second.is_ok() { last_completion_request = Some(request); @@ -362,6 +394,7 @@ async fn generate_response( WorkerRequest::CodeActionResolveRequest(request) => { do_code_action_resolve(transformer_backends, memory_backend_tx, &request, &config).await } + WorkerRequest::Shutdown => unreachable!(), } } @@ -371,36 +404,18 @@ struct CodeActionResolveData { range: Range, } -// TODO: @silas we need to make this compatible with any llm backend -async fn do_code_action_resolve( +async fn do_chat_code_action_resolve( + action: &config::Chat, transformer_backends: Arc>>, memory_backend_tx: std::sync::mpsc::Sender, request: &CodeActionResolveRequest, - config: &Config, -) -> anyhow::Result { - let chats = match config.get_chat() { - Some(chats) => chats, - None => { - return Ok(Response { - id: request.id.clone(), - result: None, - error: None, - }); - } - }; - let chat = chats - .into_iter() - .find(|chat| chat.action_display_name == request.params.title) - .with_context(|| { - format!( - "could not resolve action with title: {}", - request.params.title - ) - })?; - - let transformer_backend = transformer_backends - .get(&chat.model) - .with_context(|| format!("model: {} not found when resolving code action", chat.model))?; +) -> anyhow::Result { + let transformer_backend = transformer_backends.get(&action.model).with_context(|| { + format!( + "model: {} not found when resolving code action", + action.model + ) + })?; let data: CodeActionResolveData = serde_json::from_value( request @@ -421,14 +436,14 @@ async fn do_code_action_resolve( )))?; let file_text = rx.await?; - let (messages_text, text_edit_line, text_edit_char) = if chat.trigger == "" { + let (messages_text, text_edit_line, text_edit_char) = if action.trigger == "" { ( file_text.as_str(), file_text.lines().count(), file_text.lines().last().unwrap_or("").chars().count(), ) } else { - let mut split = file_text.split(&chat.trigger); + let mut split = file_text.splitn(2, &action.trigger); let text_edit_line = split .next() .context("trigger not found when resolving chat code action")? @@ -486,7 +501,7 @@ async fn do_code_action_resolve( // Add the messages to the params messages // NOTE: Once again we are making some assumptions that the messages key is even the right key to use here - let mut params = chat.parameters.clone(); + let mut params = action.parameters.clone(); if let Some(messages) = params.get_mut("messages") { messages .as_array_mut() @@ -527,19 +542,161 @@ async fn do_code_action_resolve( ); let changes = HashMap::from([(data.text_document.uri, vec![edit])]); + Ok(CodeAction { + title: action.action_display_name.clone(), + edit: Some(WorkspaceEdit { + changes: Some(changes), + ..Default::default() + }), + ..Default::default() + }) +} + +async fn do_code_action_action_resolve( + action: &config::Action, + transformer_backends: Arc>>, + memory_backend_tx: std::sync::mpsc::Sender, + request: &CodeActionResolveRequest, +) -> anyhow::Result { + let transformer_backend = transformer_backends.get(&action.model).with_context(|| { + format!( + "model: {} not found when resolving code action", + action.model + ) + })?; + + let data: CodeActionResolveData = serde_json::from_value( + request + .params + .data + .clone() + .context("the `data` field is required to resolve a code action")?, + ) + .context("the `data` field could not be deserialized when resolving the code action")?; + + let params = serde_json::to_value(action.parameters.clone()).unwrap(); + + // Get the prompt + let text_document_position = TextDocumentPositionParams { + text_document: data.text_document.clone(), + position: data.range.start, + }; + let (tx, rx) = oneshot::channel(); + memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new( + text_document_position, + transformer_backend.get_prompt_type(¶ms)?, + params.clone(), + tx, + )))?; + let mut prompt = rx.await?; + + // If they have some text highlighted and we aren't doing FIM let's get it + if matches!(prompt, Prompt::ContextAndCode(_)) && data.range.start != data.range.end { + // Get the file + let (tx, rx) = oneshot::channel(); + memory_backend_tx.send(memory_worker::WorkerRequest::File(FileRequest::new( + TextDocumentIdentifier { + uri: data.text_document.uri.clone(), + }, + tx, + )))?; + let file_text = rx.await?; + + // Get the text + let lines: Vec<&str> = file_text.lines().collect(); + let mut result = String::new(); + for (i, line) in lines + .iter() + .enumerate() + .skip(data.range.start.line as usize) + .take((data.range.end.line - data.range.start.line + 1) as usize) + { + let start_char = if i == data.range.start.line as usize { + data.range.start.character as usize + } else { + 0 + }; + let end_char = if i == data.range.end.line as usize { + data.range.end.character as usize + 1 + } else { + line.len() + }; + + if start_char < line.len() { + result.push_str(&line[start_char..end_char.min(line.len())]); + } + + if i != data.range.end.line as usize { + result.push('\n'); + } + } + + // Update our prompt to include the selected text + if let Prompt::ContextAndCode(prompt) = &mut prompt { + prompt.selected_text = Some(result) + } + } + + // Get the response + let mut response = transformer_backend.do_completion(&prompt, params).await?; + response.insert_text = + post_process_response(response.insert_text, &prompt, &action.post_process); + + let edit = TextEdit::new( + Range::new( + Position::new(data.range.start.line, data.range.start.character), + Position::new(data.range.end.line, data.range.end.character), + ), + response.insert_text.clone(), + ); + let changes = HashMap::from([(data.text_document.uri, vec![edit])]); + + Ok(CodeAction { + title: action.action_display_name.clone(), + edit: Some(WorkspaceEdit { + changes: Some(changes), + ..Default::default() + }), + ..Default::default() + }) +} + +// TODO: @silas we need to make this compatible with any llm backend +async fn do_code_action_resolve( + transformer_backends: Arc>>, + memory_backend_tx: std::sync::mpsc::Sender, + request: &CodeActionResolveRequest, + config: &Config, +) -> anyhow::Result { + let action = if let Some(chat_action) = config + .get_chats() + .iter() + .find(|chat_action| chat_action.action_display_name == request.params.title) + { + do_chat_code_action_resolve( + chat_action, + transformer_backends, + memory_backend_tx, + request, + ) + .await? + } else { + let action = config + .get_actions() + .iter() + .find(|action| action.action_display_name == request.params.title) + .with_context(|| { + format!( + "action: {} does not exist in `chats` or `actions`", + request.params.title + ) + })?; + do_code_action_action_resolve(action, transformer_backends, memory_backend_tx, request) + .await? + }; Ok(Response { id: request.id.clone(), - result: Some( - serde_json::to_value(CodeAction { - title: chat.action_display_name.clone(), - edit: Some(WorkspaceEdit { - changes: Some(changes), - ..Default::default() - }), - ..Default::default() - }) - .unwrap(), - ), + result: Some(serde_json::to_value(action).unwrap()), error: None, }) } @@ -549,16 +706,8 @@ async fn do_code_action_request( request: &CodeActionRequest, config: &Config, ) -> anyhow::Result { - let chats = match config.get_chat() { - Some(chats) => chats, - None => { - return Ok(Response { - id: request.id.clone(), - result: None, - error: None, - }); - } - }; + let actions = config.get_actions(); + let chats = config.get_chats(); let enabled_chats = futures::future::join_all(chats.iter().map(|chat| async { let (tx, rx) = oneshot::channel(); @@ -578,7 +727,7 @@ async fn do_code_action_request( .into_iter() .collect::>>()?; - let code_actions: Vec = chats + let mut code_actions: Vec = chats .into_iter() .zip(enabled_chats) .filter(|(_, is_enabled)| *is_enabled) @@ -595,6 +744,20 @@ async fn do_code_action_request( }) .collect(); + code_actions.extend(actions.into_iter().map(|action| { + CodeAction { + title: action.action_display_name.to_owned(), + data: Some( + serde_json::to_value(CodeActionResolveData { + text_document: request.params.text_document.clone(), + range: request.params.range, + }) + .unwrap(), + ), + ..Default::default() + } + })); + Ok(Response { id: request.id.clone(), result: Some(serde_json::to_value(&code_actions).unwrap()), @@ -840,6 +1003,7 @@ mod tests { let prompt = Prompt::ContextAndCode(ContextAndCodePrompt { context: "".to_string(), code: "tt ".to_string(), + selected_text: None, }); let response = "tt abc".to_string(); let new_response = post_process_response(response.clone(), &prompt, &config); @@ -848,6 +1012,7 @@ mod tests { let prompt = Prompt::ContextAndCode(ContextAndCodePrompt { context: "".to_string(), code: "ff".to_string(), + selected_text: None, }); let response = "zz".to_string(); let new_response = post_process_response(response.clone(), &prompt, &config); @@ -856,6 +1021,7 @@ mod tests { let prompt = Prompt::ContextAndCode(ContextAndCodePrompt { context: "".to_string(), code: "tt tt".to_string(), + selected_text: None, }); let response = "tt abc tt".to_string(); let new_response = post_process_response(response.clone(), &prompt, &config); @@ -864,6 +1030,7 @@ mod tests { let prompt = Prompt::ContextAndCode(ContextAndCodePrompt { context: "".to_string(), code: "dd".to_string(), + selected_text: None, }); let response = "zz".to_string(); let new_response = post_process_response(response.clone(), &prompt, &config); diff --git a/crates/lsp-ai/src/utils.rs b/crates/lsp-ai/src/utils.rs index e0a988e..b5d6fcd 100644 --- a/crates/lsp-ai/src/utils.rs +++ b/crates/lsp-ai/src/utils.rs @@ -38,21 +38,25 @@ pub(crate) fn format_chat_messages( ) -> Vec { messages .iter() - .map(|m| { - ChatMessage::new( - m.role.to_owned(), - format_context_code_in_str(&m.content, &prompt.context, &prompt.code), - ) - }) + .map(|m| ChatMessage::new(m.role.to_owned(), format_prompt_in_str(&m.content, &prompt))) .collect() } -pub(crate) fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String { - s.replace("{CONTEXT}", context).replace("{CODE}", code) +pub(crate) fn format_prompt_in_str(s: &str, prompt: &ContextAndCodePrompt) -> String { + s.replace("{CONTEXT}", &prompt.context) + .replace("{CODE}", &prompt.code) + .replace( + "{SELECTED_TEXT}", + prompt + .selected_text + .as_ref() + .map(|x| x.as_str()) + .unwrap_or_default(), + ) } -pub(crate) fn format_context_code(context: &str, code: &str) -> String { - format!("{context}\n\n{code}") +pub(crate) fn format_prompt(prompt: &ContextAndCodePrompt) -> String { + format!("{}\n\n{}", &prompt.context, &prompt.code) } pub(crate) fn chunk_to_id(uri: &str, chunk: &Chunk) -> String { diff --git a/crates/lsp-ai/tests/integration_tests.rs b/crates/lsp-ai/tests/integration_tests.rs index b523e12..4b1439a 100644 --- a/crates/lsp-ai/tests/integration_tests.rs +++ b/crates/lsp-ai/tests/integration_tests.rs @@ -42,7 +42,7 @@ fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> { } // This chat completion sequence was created using helix with lsp-ai and reading the logs -// It utilizes Ollama with llama3:8b-instruct-q4_0 and a temperature of 0 +// It utilizes Ollama with llama3.1:8b and a temperature of 0 // It starts with a Python file: // ``` // # Multiplies two numbers @@ -287,3 +287,253 @@ fn test_completion_sequence() -> Result<()> { child.kill()?; Ok(()) } + +// This chat sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with llama3:8b-instruct-q4_0 and a temperature of 0 +// It starts with a Markdown file: +// ``` +// !C Who are +// ``` +// And has the following sequence of key strokes: +// A on line 1 (this enters insert mode at the end of line 1) +// (space) +// y +// o +// u +// (esc) +// (run code action for chatting) +// ge (goes to end of file) +// o (inserts a new line below cursor) +// ? +// (esc) +// (run code action for chatting) +#[test] +fn test_chat_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"formatting":{"dynamicRegistration":false},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.7 (0d62656c)"},"initializationOptions":{"chat":[{"action_display_name":"Chat","model":"model1","parameters":{"max_context":4096,"max_tokens":1024,"messages":[{"content":"You are a code assistant chatbot. The user will ask you for assistance coding and you will do you best to answer succinctly and accurately","role":"system"}],"options":{"temperature":0}},"trigger":"!C"}],"memory":{"file_store":{}},"models":{"model1":{"model":"llama3:8b-instruct-q4_0","type":"ollama"}}},"processId":50522,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##; + send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"markdown","text":"!C Who are\n","uri":"file:///fake.md","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":10,"line":0},"start":{"character":10,"line":0}},"text":" "}],"textDocument":{"uri":"file:///fake.md","version":1}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":11,"line":0},"start":{"character":11,"line":0}},"text":"y"}],"textDocument":{"uri":"file:///fake.md","version":2}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":12,"line":0},"start":{"character":12,"line":0}},"text":"o"}],"textDocument":{"uri":"file:///fake.md","version":3}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":13,"line":0},"start":{"character":13,"line":0}},"text":"u"}],"textDocument":{"uri":"file:///fake.md","version":4}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/codeAction","params":{"context":{"diagnostics":[],"triggerKind":1},"range":{"end":{"character":0,"line":1},"start":{"character":14,"line":0}},"textDocument":{"uri":"file:///fake.md"}},"id":3}"##, + )?; + + // Test that our action is present + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":3,"result":[{"data":{"range":{"end":{"character":0,"line":1},"start":{"character":14,"line":0}},"text_document":{"uri":"file:///fake.md"}},"title":"Chat"}]}"## + ); + + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"codeAction/resolve","params":{"data":{"range":{"end":{"character":0,"line":1},"start":{"character":14,"line":0}},"text_document":{"uri":"file:///fake.md"}},"title":"Chat"},"id":4}"##, + )?; + + // Test that we get the corret model output + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":4,"result":{"edit":{"changes":{"file:///fake.md":[{"newText":"\n\n<|assistant|>\nI'm CodePal, your friendly AI code assistant! I'm here to help you with any programming-related questions or problems you might have. Whether you're a beginner or an experienced developer, I'll do my best to provide clear and concise answers to get you back on track. What can I help you with today?\n\n<|user|>\n","range":{"end":{"character":12,"line":1},"start":{"character":12,"line":1}}}]}},"title":"Chat"}}"## + ); + + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":0,"line":1},"start":{"character":0,"line":1}},"text":"\n\n<|assistant|>\nI'm CodePal, your friendly AI code assistant! I'm here to help you with any programming-related questions or problems you might have. Whether you're a beginner or an experienced developer, I'll do my best to provide clear and concise answers to get you back on track. What can I help you with today?\n\n<|user|>\n"}],"textDocument":{"uri":"file:///fake.md","version":5}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":6},"start":{"character":8,"line":6}},"text":"\n"}],"textDocument":{"uri":"file:///fake.md","version":6}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":0,"line":7},"start":{"character":0,"line":7}},"text":"?"}],"textDocument":{"uri":"file:///fake.md","version":7}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/codeAction","params":{"context":{"diagnostics":[],"triggerKind":1},"range":{"end":{"character":0,"line":8},"start":{"character":1,"line":7}},"textDocument":{"uri":"file:///fake.md"}},"id":5}"##, + )?; + + // Test that our action is present + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":5,"result":[{"data":{"range":{"end":{"character":0,"line":8},"start":{"character":1,"line":7}},"text_document":{"uri":"file:///fake.md"}},"title":"Chat"}]}"## + ); + + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"codeAction/resolve","params":{"data":{"range":{"end":{"character":0,"line":8},"start":{"character":1,"line":7}},"text_document":{"uri":"file:///fake.md"}},"title":"Chat"},"id":6}"##, + )?; + + // Test that we get the correct model output + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":6,"result":{"edit":{"changes":{"file:///fake.md":[{"newText":"\n\n<|assistant|>\nIt seems like you didn't ask a question! That's okay, I'm here to help whenever you're ready. If you have any programming-related questions or need assistance with a specific coding problem, feel free to ask me anything!\n\n<|user|>\n","range":{"end":{"character":1,"line":8},"start":{"character":1,"line":8}}}]}},"title":"Chat"}}"## + ); + + child.kill()?; + Ok(()) +} + +// This custom action completion sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with llama3.1:8b and a temperature of 0 +// It starts with a Python file: +// ``` +// def fib(n): +// ``` +// And has the following sequence of key strokes: +// gl (goes to end of line with cursor on ":") +// l (moves cursor to position after ":") +// (run code action for Completion) +#[test] +fn test_completion_action_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"formatting":{"dynamicRegistration":false},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.7 (0d62656c)"},"initializationOptions":{"actions":[{"action_display_name":"Complete","model":"model1","parameters":{"max_context":4096,"max_tokens":4096,"messages":[{"content":"You are an AI coding assistant. Your task is to complete code snippets. The user's cursor position is marked by \"\". Follow these steps:\n\n1. Analyze the code context and the cursor position.\n2. Provide your chain of thought reasoning, wrapped in tags. Include thoughts about the cursor position, what needs to be completed, and any necessary formatting.\n3. Determine the appropriate code to complete the current thought, including finishing partial words or lines.\n4. Replace \"\" with the necessary code, ensuring proper formatting and line breaks.\n5. Wrap your code solution in tags.\n\nYour response should always include both the reasoning and the answer. Pay special attention to completing partial words or lines before adding new lines of code.\n\n\n\nUser input:\n--main.py--\n# A function that reads in user inpu\n\nResponse:\n\n1. The cursor is positioned after \"inpu\" in a comment describing a function that reads user input.\n2. We need to complete the word \"input\" in the comment first.\n3. After completing the comment, we should add a new line before defining the function.\n4. The function should use Python's built-in `input()` function to read user input.\n5. We'll name the function descriptively and include a return statement.\n\n\nt\ndef read_user_input():\n user_input = input(\"Enter your input: \")\n return user_input\n\n\n\n\nUser input:\n--main.py--\ndef fibonacci(n):\n if n <= 1:\n return n\n else:\n re\n\n\nResponse:\n\n1. The cursor is positioned after \"re\" in the 'else' clause of a recursive Fibonacci function.\n2. We need to complete the return statement for the recursive case.\n3. The \"re\" already present likely stands for \"return\", so we'll continue from there.\n4. The Fibonacci sequence is the sum of the two preceding numbers.\n5. We should return the sum of fibonacci(n-1) and fibonacci(n-2).\n\n\nturn fibonacci(n-1) + fibonacci(n-2)\n\n\n","role":"system"},{"content":"{CODE}","role":"user"}],"options":{"temperature":0}},"post_process":{"extractor":"(?s)(.*?)"}}],"memory":{"file_store":{}},"models":{"model1":{"model":"llama3.1:8b","type":"ollama"}}},"processId":55832,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##; + send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"def fib(n):\n","uri":"file:///fake.py","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/codeAction","params":{"context":{"diagnostics":[],"triggerKind":1},"range":{"end":{"character":0,"line":1},"start":{"character":11,"line":0}},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, + )?; + + // Test that our action is present + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":1,"result":[{"data":{"range":{"end":{"character":0,"line":1},"start":{"character":11,"line":0}},"text_document":{"uri":"file:///fake.py"}},"title":"Complete"}]}"## + ); + + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"codeAction/resolve","params":{"data":{"range":{"end":{"character":0,"line":1},"start":{"character":11,"line":0}},"text_document":{"uri":"file:///fake.py"}},"title":"Complete"},"id":2}"##, + )?; + + // Test that we get the corret model output + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":2,"result":{"edit":{"changes":{"file:///fake.py":[{"newText":"\n if n <= 1:\n return n\n else:\n return fib(n-1) + fib(n-2)","range":{"end":{"character":0,"line":1},"start":{"character":11,"line":0}}}]}},"title":"Complete"}}"## + ); + + child.kill()?; + Ok(()) +} + +// This custom action refactor sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with llama3.1:8b and a temperature of 0 +// It starts with a Python file: +// ``` +// # The fibonacci recursively +// def fib(n): +// if n <= 1: +// return n +// a, b = 0, 1 +// for _ in range(2, n + 1): +// a, b = b, a + b +// return b +// ``` +// And has the following sequence of key strokes: +// % (selects all text) +// (run code action for Refactoring) +#[test] +fn test_refactor_action_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"formatting":{"dynamicRegistration":false},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.7 (0d62656c)"},"initializationOptions":{"actions":[{"action_display_name":"Refactor","model":"model1","parameters":{"max_context":4096,"max_tokens":4096,"messages":[{"content":"You are an AI coding assistant specializing in code refactoring. Your task is to analyze the given code snippet and provide a refactored version. Follow these steps:\n\n1. Analyze the code context and structure.\n2. Identify areas for improvement, such as code efficiency, readability, or adherence to best practices.\n3. Provide your chain of thought reasoning, wrapped in tags. Include your analysis of the current code and explain your refactoring decisions.\n4. Rewrite the entire code snippet with your refactoring applied.\n5. Wrap your refactored code solution in tags.\n\nYour response should always include both the reasoning and the refactored code.\n\n\n\nUser input:\ndef calculate_total(items):\n total = 0\n for item in items:\n total = total + item['price'] * item['quantity']\n return total\n\n\nResponse:\n\n1. The function calculates the total cost of items based on price and quantity.\n2. We can improve readability and efficiency by:\n a. Using a more descriptive variable name for the total.\n b. Utilizing the sum() function with a generator expression.\n c. Using augmented assignment (+=) if we keep the for loop.\n3. We'll implement the sum() function approach for conciseness.\n4. We'll add a type hint for better code documentation.\n\n\nfrom typing import List, Dict\n\ndef calculate_total(items: List[Dict[str, float]]) -> float:\n return sum(item['price'] * item['quantity'] for item in items)\n\n\n\n\nUser input:\ndef is_prime(n):\n if n < 2:\n return False\n for i in range(2, n):\n if n % i == 0:\n return False\n return True\n\n\nResponse:\n\n1. This function checks if a number is prime, but it's not efficient for large numbers.\n2. We can improve it by:\n a. Adding an early return for 2, the only even prime number.\n b. Checking only odd numbers up to the square root of n.\n c. Using a more efficient range (start at 3, step by 2).\n3. We'll also add a type hint for better documentation.\n4. The refactored version will be more efficient for larger numbers.\n\n\nimport math\n\ndef is_prime(n: int) -> bool:\n if n < 2:\n return False\n if n == 2:\n return True\n if n % 2 == 0:\n return False\n \n for i in range(3, int(math.sqrt(n)) + 1, 2):\n if n % i == 0:\n return False\n return True\n\n\n","role":"system"},{"content":"{SELECTED_TEXT}","role":"user"}],"options":{"temperature":0}},"post_process":{"extractor":"(?s)(.*?)"}}],"memory":{"file_store":{}},"models":{"model1":{"model":"llama3.1:8b","type":"ollama"}}},"processId":56258,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##; + send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# The fibonacci recursively\ndef fib(n):\n if n <= 1:\n return n\n a, b = 0, 1\n for _ in range(2, n + 1):\n a, b = b, a + b\n return b\n","uri":"file:///fake.py","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/codeAction","params":{"context":{"diagnostics":[],"triggerKind":1},"range":{"end":{"character":0,"line":8},"start":{"character":0,"line":0}},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, + )?; + + // Test that our action is present + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":1,"result":[{"data":{"range":{"end":{"character":0,"line":8},"start":{"character":0,"line":0}},"text_document":{"uri":"file:///fake.py"}},"title":"Refactor"}]}"## + ); + + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"codeAction/resolve","params":{"data":{"range":{"end":{"character":0,"line":8},"start":{"character":0,"line":0}},"text_document":{"uri":"file:///fake.py"}},"title":"Refactor"},"id":2}"##, + )?; + + // Test that we get the corret model output + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":2,"result":{"edit":{"changes":{"file:///fake.py":[{"newText":"\nfrom typing import Dict\n\ndef fib(n: int) -> int:\n memo: Dict[int, int] = {0: 0, 1: 1}\n \n def calculate_fibonacci(k: int) -> int:\n if k not in memo:\n memo[k] = calculate_fibonacci(k - 1) + calculate_fibonacci(k - 2)\n return memo[k]\n \n return calculate_fibonacci(n)\n","range":{"end":{"character":0,"line":8},"start":{"character":0,"line":0}}}]}},"title":"Refactor"}}"## + ); + + child.kill()?; + Ok(()) +} From 57ee3c0936f28d6b24730d7b30ef9e11d883b978 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 17 Aug 2024 12:39:32 -0700 Subject: [PATCH 2/2] Bump version --- crates/lsp-ai/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/lsp-ai/Cargo.toml b/crates/lsp-ai/Cargo.toml index 088e5b3..c329001 100644 --- a/crates/lsp-ai/Cargo.toml +++ b/crates/lsp-ai/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lsp-ai" -version = "0.5.1" +version = "0.6.0" description.workspace = true repository.workspace = true