Skip to content

Commit

Permalink
Merge pull request #23 from SilasMarvin/silas-add-initial-post-processing
Browse files Browse the repository at this point in the history

Added initial post processing to remove duplicate start and end characters
  • Loading branch information
SilasMarvin authored Jun 14, 2024
2 parents 736b137 + 17cb6cf commit cd46ecf
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lsp-ai"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
license = "MIT"
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
Expand Down
27 changes: 26 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@ const fn max_requests_per_second_default() -> f32 {
1.
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PostProcess {
pub remove_duplicate_start: bool,
pub remove_duplicate_end: bool,
}

impl Default for PostProcess {
fn default() -> Self {
Self {
remove_duplicate_start: true,
remove_duplicate_end: true,
}
}
}

#[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend {
#[serde(rename = "file_store")]
Expand Down Expand Up @@ -177,10 +192,12 @@ pub struct Anthropic {
pub struct Completion {
// The model key to use
pub model: String,

// Args are deserialized by the backend using them
#[serde(default)]
pub parameters: Kwargs,
// Parameters for post processing
#[serde(default)]
pub post_process: PostProcess,
}

#[derive(Clone, Debug, Deserialize)]
Expand Down Expand Up @@ -230,6 +247,10 @@ impl Config {
self.config.completion.is_some()
}

/// Returns the post-processing settings for completions, or `None` when
/// completions are not configured.
pub fn get_completions_post_process(&self) -> Option<&PostProcess> {
    let completion = self.config.completion.as_ref()?;
    Some(&completion.post_process)
}

pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
match &self
.config
Expand Down Expand Up @@ -335,6 +356,10 @@ mod test {
"options": {
"num_predict": 32
}
},
"post_process": {
"remove_duplicate_start": true,
"remove_duplicate_end": true,
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions src/custom_requests/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@ use lsp_types::TextDocumentPositionParams;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::config;

pub enum Generation {}

#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
/// Parameters for a one-shot generation request.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationParams {
    // This field was "mixed-in" from TextDocumentPositionParams: the document
    // and cursor position to generate at.
    #[serde(flatten)]
    pub text_document_position: TextDocumentPositionParams,
    // The model key to use
    pub model: String,
    // Args are deserialized by the backend using them; absent in the request
    // means an empty/default `Value`.
    #[serde(default)]
    pub parameters: Value,
    // Parameters for post processing (falls back to `PostProcess::default()`
    // when omitted from the request).
    #[serde(default)]
    pub post_process: config::PostProcess,
}

#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateResult {
pub generated_text: String,
Expand Down
164 changes: 161 additions & 3 deletions src/transformer_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use std::time::{Duration, SystemTime};
use tokio::sync::oneshot;
use tracing::{error, instrument};

use crate::config::Config;
use crate::config::{self, Config};
use crate::custom_requests::generation::{GenerateResult, GenerationParams};
use crate::custom_requests::generation_stream::GenerationStreamParams;
use crate::memory_backends::Prompt;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
Expand Down Expand Up @@ -85,6 +86,83 @@ pub struct DoGenerationStreamResponse {
pub generated_text: String,
}

// Strips from the start of `response` the longest prefix that duplicates a
// suffix of `front` (the document text immediately before the cursor).
// E.g. if the document ends with "tt " and the model echoes "tt abc", the
// returned text is "abc". Returns the response unchanged when no overlap
// exists.
fn post_process_start(response: String, front: &str) -> String {
    // Scan from the longest possible prefix down to the empty one; the first
    // hit is the longest duplicated prefix.
    let mut front_match = response.len();
    loop {
        if front_match == 0 {
            break;
        }
        // Only slice at char boundaries: byte-indexed slicing of multi-byte
        // UTF-8 text panics otherwise.
        if response.is_char_boundary(front_match) && front.ends_with(&response[..front_match]) {
            break;
        }
        front_match -= 1;
    }
    if front_match > 0 {
        response[front_match..].to_owned()
    } else {
        response
    }
}

// Strips from the end of `response` the longest suffix that duplicates a
// prefix of `back` (the document text immediately after the cursor).
// Returns the response unchanged when no overlap exists.
fn post_process_end(response: String, back: &str) -> String {
    // Scan from the longest possible suffix (back_match == 0) towards the
    // shortest; the first hit is the longest duplicated suffix.
    let mut back_match = 0;
    while back_match < response.len() {
        // Only slice at char boundaries: byte-indexed slicing of multi-byte
        // UTF-8 text panics otherwise.
        if response.is_char_boundary(back_match) && back.starts_with(&response[back_match..]) {
            break;
        }
        back_match += 1;
    }
    if back_match == response.len() {
        // No duplicated suffix found — keep the response untouched.
        response
    } else {
        // NOTE: when the entire response duplicates `back` (back_match == 0)
        // it is removed, mirroring post_process_start's handling of a full
        // duplicate; the previous guard `back_match > 0` returned it intact.
        response[..back_match].to_owned()
    }
}

// Some basic post processing that will clean up duplicate characters at the
// front and back of the response, relative to the surrounding document text.
fn post_process_response(
    response: String,
    prompt: &Prompt,
    config: &config::PostProcess,
) -> String {
    match prompt {
        Prompt::ContextAndCode(context_and_code) => {
            if context_and_code.code.contains("<CURSOR>") {
                // `contains` guarantees at least two pieces. Bind both up
                // front: previously the iterator was advanced lazily inside
                // the `remove_duplicate_start` branch, so disabling start
                // dedup made the end dedup run against the PRE-cursor text.
                let mut split = context_and_code.code.split("<CURSOR>");
                let front = split.next().unwrap();
                let back = split.next().unwrap();
                let response = if config.remove_duplicate_start {
                    post_process_start(response, front)
                } else {
                    response
                };
                if config.remove_duplicate_end {
                    post_process_end(response, back)
                } else {
                    response
                }
            } else if config.remove_duplicate_start {
                // No cursor marker: all of the code precedes the response,
                // so only the front can duplicate.
                post_process_start(response, &context_and_code.code)
            } else {
                response
            }
        }
        Prompt::FIM(fim) => {
            // Fill-in-the-middle: the prompt precedes the response and the
            // suffix follows it.
            let response = if config.remove_duplicate_start {
                post_process_start(response, &fim.prompt)
            } else {
                response
            };
            if config.remove_duplicate_end {
                post_process_end(response, &fim.suffix)
            } else {
                response
            }
        }
    }
}

pub fn run(
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
Expand Down Expand Up @@ -249,6 +327,7 @@ async fn do_completion(
)
.unwrap();

// Build the prompt
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
Expand All @@ -258,13 +337,22 @@ async fn do_completion(
)))?;
let prompt = rx.await?;

// Get the filter text
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
FilterRequest::new(request.params.text_document_position.clone(), tx),
))?;
let filter_text = rx.await?;

let response = transformer_backend.do_completion(&prompt, params).await?;
// Get the response
let mut response = transformer_backend.do_completion(&prompt, params).await?;
eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text);

if let Some(post_process) = config.get_completions_post_process() {
response.insert_text = post_process_response(response.insert_text, &prompt, &post_process);
}

// Build and send the response
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
Expand Down Expand Up @@ -314,7 +402,13 @@ async fn do_generate(
)))?;
let prompt = rx.await?;

let response = transformer_backend.do_generate(&prompt, params).await?;
let mut response = transformer_backend.do_generate(&prompt, params).await?;
response.generated_text = post_process_response(
response.generated_text,
&prompt,
&request.params.post_process,
);

let result = GenerateResult {
generated_text: response.generated_text,
};
Expand All @@ -325,3 +419,67 @@ async fn do_generate(
error: None,
})
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};

    // Exercises post processing for fill-in-the-middle prompts with the
    // default config (both start and end dedup enabled).
    #[test]
    fn test_post_process_fim() {
        let config = config::PostProcess::default();

        // "4 " duplicates the end of the prompt and "tta" duplicates the
        // start of the suffix, so both overlaps are stripped.
        let prompt = Prompt::FIM(FIMPrompt {
            prompt: "test 1234 ".to_string(),
            suffix: "ttabc".to_string(),
        });
        let response = "4 zz tta".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "zz ");

        // No overlap with prompt or suffix — response passes through untouched.
        let prompt = Prompt::FIM(FIMPrompt {
            prompt: "test".to_string(),
            suffix: "test".to_string(),
        });
        let response = "zzzz".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "zzzz");
    }

    // Exercises post processing for context-and-code prompts, both with and
    // without a "<CURSOR>" marker inside the code.
    #[test]
    fn test_post_process_context_and_code() {
        let config = config::PostProcess::default();

        // No cursor marker: the whole code precedes the response, so the
        // duplicated "tt " prefix is stripped.
        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: "tt ".to_string(),
        });
        let response = "tt abc".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "abc");

        // No overlap — response passes through untouched.
        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: "ff".to_string(),
        });
        let response = "zz".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "zz");

        // Cursor marker: text before it dedupes the front ("tt "), text
        // after it dedupes the back (" tt").
        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: "tt <CURSOR> tt".to_string(),
        });
        let response = "tt abc tt".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "abc");

        // Cursor marker with no overlap on either side — untouched.
        let prompt = Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: "d<CURSOR>d".to_string(),
        });
        let response = "zz".to_string();
        let new_response = post_process_response(response.clone(), &prompt, &config);
        assert_eq!(new_response, "zz");
    }
}

0 comments on commit cd46ecf

Please sign in to comment.