diff --git a/Cargo.toml b/Cargo.toml
index f98cef6..18dfb33 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lsp-ai"
-version = "0.2.0"
+version = "0.3.0"
 edition = "2021"
 license = "MIT"
 description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
diff --git a/src/config.rs b/src/config.rs
index 82df19e..8cbeadd 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -9,6 +9,21 @@ const fn max_requests_per_second_default() -> f32 {
     1.
 }
 
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct PostProcess {
+    pub remove_duplicate_start: bool,
+    pub remove_duplicate_end: bool,
+}
+
+impl Default for PostProcess {
+    fn default() -> Self {
+        Self {
+            remove_duplicate_start: true,
+            remove_duplicate_end: true,
+        }
+    }
+}
+
 #[derive(Debug, Clone, Deserialize)]
 pub enum ValidMemoryBackend {
     #[serde(rename = "file_store")]
@@ -177,10 +192,12 @@ pub struct Anthropic {
 pub struct Completion {
     // The model key to use
     pub model: String,
 
-    // Args are deserialized by the backend using them
     #[serde(default)]
     pub parameters: Kwargs,
+    // Parameters for post processing
+    #[serde(default)]
+    pub post_process: PostProcess,
 }
 
 #[derive(Clone, Debug, Deserialize)]
@@ -230,6 +247,10 @@ impl Config {
         self.config.completion.is_some()
     }
 
+    pub fn get_completions_post_process(&self) -> Option<&PostProcess> {
+        self.config.completion.as_ref().map(|x| &x.post_process)
+    }
+
     pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
         match &self
             .config
@@ -335,6 +356,10 @@ mod test {
                     "options": {
                         "num_predict": 32
                     }
-                }
+                },
+                "post_process": {
+                    "remove_duplicate_start": true,
+                    "remove_duplicate_end": true
+                }
             }
         }
diff --git a/src/custom_requests/generation.rs b/src/custom_requests/generation.rs
index e2b3450..725cccb 100644
--- a/src/custom_requests/generation.rs
+++ b/src/custom_requests/generation.rs
@@ -2,20 +2,27 @@ use lsp_types::TextDocumentPositionParams;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
+use crate::config;
+
 pub enum Generation {}
 
-#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerationParams {
     // This field was "mixed-in" from TextDocumentPositionParams
     #[serde(flatten)]
     pub text_document_position: TextDocumentPositionParams,
+    // The model key to use
     pub model: String,
+    // Args are deserialized by the backend using them
     #[serde(default)]
     pub parameters: Value,
+    // Parameters for post processing
+    #[serde(default)]
+    pub post_process: config::PostProcess,
 }
 
-#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerateResult {
     pub generated_text: String,
diff --git a/src/transformer_worker.rs b/src/transformer_worker.rs
index a6b184f..196447b 100644
--- a/src/transformer_worker.rs
+++ b/src/transformer_worker.rs
@@ -11,9 +11,10 @@ use std::time::{Duration, SystemTime};
 use tokio::sync::oneshot;
 use tracing::{error, instrument};
 
-use crate::config::Config;
+use crate::config::{self, Config};
 use crate::custom_requests::generation::{GenerateResult, GenerationParams};
 use crate::custom_requests::generation_stream::GenerationStreamParams;
+use crate::memory_backends::Prompt;
 use crate::memory_worker::{self, FilterRequest, PromptRequest};
 use crate::transformer_backends::TransformerBackend;
 use crate::utils::ToResponseError;
@@ -85,6 +86,85 @@ pub struct DoGenerationStreamResponse {
     pub generated_text: String,
 }
 
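+// Removes the longest prefix of `response` that duplicates the end of `front`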
+fn post_process_start(response: String, front: &str) -> String {
+    let mut front_match = response.len();
+    loop {
+        if response.is_empty() || front.ends_with(&response[..front_match]) {
+            break;
+        } else {
+            front_match -= 1;
+        }
+    }
+    if front_match > 0 {
+        response[front_match..].to_owned()
+    } else {
+        response
+    }
+}
+
+// Removes the longest suffix of `response` that duplicates the start of `back`
+fn post_process_end(response: String, back: &str) -> String {
+    let mut back_match = 0;
+    loop {
+        if back_match == response.len() {
+            break;
+        } else if back.starts_with(&response[back_match..]) {
+            break;
+        } else {
+            back_match += 1;
+        }
+    }
+    if back_match > 0 {
+        response[..back_match].to_owned()
+    } else {
+        response
+    }
+}
+
+// Some basic post processing that will clean up duplicate characters at the front and back
+fn post_process_response(
+    response: String,
+    prompt: &Prompt,
+    config: &config::PostProcess,
+) -> String {
+    match prompt {
+        Prompt::ContextAndCode(context_and_code) => {
+            if context_and_code.code.contains("<CURSOR>") {
+                let mut split = context_and_code.code.split("<CURSOR>");
+                let response = if config.remove_duplicate_start {
+                    post_process_start(response, split.next().unwrap())
+                } else {
+                    response
+                };
+                if config.remove_duplicate_end {
+                    post_process_end(response, split.next().unwrap())
+                } else {
+                    response
+                }
+            } else {
+                if config.remove_duplicate_start {
+                    post_process_start(response, &context_and_code.code)
+                } else {
+                    response
+                }
+            }
+        }
+        Prompt::FIM(fim) => {
+            let response = if config.remove_duplicate_start {
+                post_process_start(response, &fim.prompt)
+            } else {
+                response
+            };
+            if config.remove_duplicate_end {
+                post_process_end(response, &fim.suffix)
+            } else {
+                response
+            }
+        }
+    }
+}
+
 pub fn run(
     transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
     memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
@@ -249,6 +329,7 @@ async fn do_completion(
     )
     .unwrap();
 
+    // Build the prompt
     let (tx, rx) = oneshot::channel();
     memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
         request.params.text_document_position.clone(),
@@ -258,13 +339,21 @@ async fn do_completion(
     )))?;
     let prompt = rx.await?;
 
+    // Get the filter text
     let (tx, rx) = oneshot::channel();
     memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
         FilterRequest::new(request.params.text_document_position.clone(), tx),
     ))?;
     let filter_text = rx.await?;
 
-    let response = transformer_backend.do_completion(&prompt, params).await?;
+    // Get the response and apply any configured post processing
+    let mut response = transformer_backend.do_completion(&prompt, params).await?;
+
+    if let Some(post_process) = config.get_completions_post_process() {
+        response.insert_text = post_process_response(response.insert_text, &prompt, post_process);
+    }
+
+    // Build and send the response
     let completion_text_edit = TextEdit::new(
         Range::new(
             Position::new(
@@ -314,7 +403,13 @@ async fn do_generate(
     )))?;
     let prompt = rx.await?;
 
-    let response = transformer_backend.do_generate(&prompt, params).await?;
+    let mut response = transformer_backend.do_generate(&prompt, params).await?;
+    response.generated_text = post_process_response(
+        response.generated_text,
+        &prompt,
+        &request.params.post_process,
+    );
+
     let result = GenerateResult {
         generated_text: response.generated_text,
     };
@@ -325,3 +420,69 @@ async fn do_generate(
         error: None,
     })
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};
+
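+    // A response that repeats the end of the FIM prompt at its start, or the
+    // start of the FIM suffix at its end, should have the overlap trimmed.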
+    #[test]
+    fn test_post_process_fim() {
+        let config = config::PostProcess::default();
+
+        let prompt = Prompt::FIM(FIMPrompt {
+            prompt: "test 1234 ".to_string(),
+            suffix: "ttabc".to_string(),
+        });
+        let response = "4 zz tta".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "zz ");
+
+        let prompt = Prompt::FIM(FIMPrompt {
+            prompt: "test".to_string(),
+            suffix: "test".to_string(),
+        });
+        let response = "zzzz".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "zzzz");
+    }
+
+    #[test]
+    fn test_post_process_context_and_code() {
+        let config = config::PostProcess::default();
+
+        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
+            context: "".to_string(),
+            code: "tt ".to_string(),
+        });
+        let response = "tt abc".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "abc");
+
+        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
+            context: "".to_string(),
+            code: "ff".to_string(),
+        });
+        let response = "zz".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "zz");
+
+        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
+            context: "".to_string(),
+            code: "tt <CURSOR> tt".to_string(),
+        });
+        let response = "tt abc tt".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "abc");
+
+        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
+            context: "".to_string(),
+            code: "d<CURSOR>d".to_string(),
+        });
+        let response = "zz".to_string();
+        let new_response = post_process_response(response.clone(), &prompt, &config);
+        assert_eq!(new_response, "zz");
+    }
+}