Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit testing #15

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros"] }
dotenv = "0.15.0"
serde_json = "1.0.128"
serde = { version = "1.0.210", features = ["derive"] }
comfy-table = "7.1.1"
comfy-table = "7.1.1"
httpmock = "0.7.0"
11 changes: 4 additions & 7 deletions src/api/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::json_models::chat_completion::ChatCompletionResponse;
use super::Instance;
use crate::clue::ClueCollection;
use crate::json_models::chat_completion::ChatCompletionResponse;
use serde_json::json;

const SYSTEM_PROMPT: &str = r#"
Expand Down Expand Up @@ -36,21 +36,18 @@ impl Instance {
.json(&request_body)
.send()
.await
.map_err(|_| "Failed to fetch clue collection from API server")?;
.map_err(|e| format!("Failed to fetch clue collection from API server: {}", e))?;

let parsed_response = response
.json::<ChatCompletionResponse>()
.await
.map_err(|_| "Failed to parse clues from API server")?;
.map_err(|e| format!("Failed to parse clues from API server: {}", e))?;

// Extract usage information from the parsed response
let token_usage = parsed_response.usage;

// Extract clue strings from the parsed response
let clue_strings = parsed_response
.choices
.first()
.ok_or("Failed to parse clues from API server")?
let clue_strings = parsed_response.choices[0]
.message
.content
.lines()
Expand Down
8 changes: 4 additions & 4 deletions src/api/language_models.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use crate::json_models::language_model::ModelsResponse;
use super::Instance;
use crate::json_models::language_model::ModelsResponse;

impl Instance {
pub async fn fetch_all_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
pub async fn fetch_language_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let response = self
.client
.get(format!("{}models", self.base_url))
.bearer_auth(&self.key)
.send()
.await
.map_err(|_| "Failed to fetch model IDs from API server")?;
.map_err(|e| format!("Failed to fetch model IDs from API server: {}", e))?;

let mut all_model_ids = response
.json::<ModelsResponse>()
.await
.map_err(|_| "Failed to parse model IDs from API server")?
.map_err(|e| format!("Failed to parse model IDs from API server: {}", e))?
.data
.iter()
.map(|model| model.id.trim().to_string())
Expand Down
6 changes: 5 additions & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Instance {
model_id: String,
) -> Result<(), Box<dyn std::error::Error>> {
// Return Error if the chosen model is not valid
let valid_model_ids = self.fetch_all_model_ids().await?;
let valid_model_ids = self.fetch_language_model_ids().await?;
if !valid_model_ids.contains(&model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
Expand All @@ -54,4 +54,8 @@ impl Instance {
self.model_id = model_id;
Ok(())
}

pub fn set_base_url(&mut self, base_url: String) {
self.base_url = base_url;
}
}
20 changes: 16 additions & 4 deletions src/clue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl ClueCollection {
pub fn display_table(&self) {
println!("{}", self.generate_table());
}

pub fn display_token_info(&self) {
eprintln!(
"\nToken Usage:\n\
Expand All @@ -108,9 +108,21 @@ impl ClueCollection {
Completion Tokens: {}\n\
----------------------\n\
Total Tokens: {}",
self.usage.prompt_tokens,
self.usage.completion_tokens,
self.usage.total_tokens
self.usage.prompt_tokens, self.usage.completion_tokens, self.usage.total_tokens
);
}

pub fn generate_raw_list(&self) -> String {
let mut raw_list = String::new();
for clue in &self.clues {
let clue_string = format!(
"{} {} - {}\n",
clue.clue_word,
clue.count,
clue.linked_words.join(", ")
);
raw_list.push_str(clue_string.as_str());
}
raw_list
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::path::PathBuf;
pub mod api;
mod clue;
mod json_models;
#[cfg(test)]
mod tests;

/// Mastermind - An LLM-powered CLI tool to help you be a better spymaster in Codenames
#[derive(Parser)]
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// If -g is set, call the models API endpoint instead
if args.get {
println!("{}", api_instance.fetch_all_model_ids().await?.join("\n"));
println!("{}", api_instance.fetch_language_model_ids().await?.join("\n"));
return Ok(());
}

Expand Down
5 changes: 5 additions & 0 deletions src/tests/expected_outputs/chat_completions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
music 2 - sound, bee
film 2 - bond, tokyo
free 2 - park, penny
dive 2 - scuba diver, hospital
large 2 - walrus, scuba diver
13 changes: 13 additions & 0 deletions src/tests/expected_outputs/language_models.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
distil-whisper-large-v3-en
gemma-7b-it
gemma2-9b-it
llama-3.1-70b-versatile
llama-3.1-8b-instant
llama-guard-3-8b
llama3-70b-8192
llama3-8b-8192
llama3-groq-70b-8192-tool-use-preview
llama3-groq-8b-8192-tool-use-preview
llava-v1.5-7b-4096-preview
mixtral-8x7b-32768
whisper-large-v3
30 changes: 30 additions & 0 deletions src/tests/mock_responses/chat_completions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"id": "chatcmpl-869ede85-2f46-4834-a039-28d757e958a5",
"object": "chat.completion",
"created": 1726870549,
"model": "llama-3.1-70b-versatile",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "music, 2, sound, bee\nfilm, 2, bond, tokyo\nfree, 2, park, penny\ndive, 2, scuba diver, hospital\nlarge, 2, walrus, scuba diver"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"queue_time": 0.005669406999999987,
"prompt_tokens": 222,
"prompt_time": 0.068204384,
"completion_tokens": 53,
"completion_time": 0.214023764,
"total_tokens": 275,
"total_time": 0.282228148
},
"system_fingerprint": "fp_b6828be2c9",
"x_groq": {
"id": "req_01j88r2wfmecr9zgpjn2zmnprb"
}
}
122 changes: 122 additions & 0 deletions src/tests/mock_responses/language_models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"object": "list",
"data": [
{
"id": "llava-v1.5-7b-4096-preview",
"object": "model",
"created": 1725402373,
"owned_by": "Other",
"active": true,
"context_window": 4096,
"public_apps": null
},
{
"id": "gemma-7b-it",
"object": "model",
"created": 1693721698,
"owned_by": "Google",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama-3.1-8b-instant",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 131072,
"public_apps": null
},
{
"id": "whisper-large-v3",
"object": "model",
"created": 1693721698,
"owned_by": "OpenAI",
"active": true,
"context_window": 448,
"public_apps": null
},
{
"id": "mixtral-8x7b-32768",
"object": "model",
"created": 1693721698,
"owned_by": "Mistral AI",
"active": true,
"context_window": 32768,
"public_apps": null
},
{
"id": "llama3-8b-8192",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "distil-whisper-large-v3-en",
"object": "model",
"created": 1693721698,
"owned_by": "Hugging Face",
"active": true,
"context_window": 448,
"public_apps": null
},
{
"id": "llama-guard-3-8b",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "gemma2-9b-it",
"object": "model",
"created": 1693721698,
"owned_by": "Google",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama3-70b-8192",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama3-groq-70b-8192-tool-use-preview",
"object": "model",
"created": 1693721698,
"owned_by": "Groq",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama-3.1-70b-versatile",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 131072,
"public_apps": null
},
{
"id": "llama3-groq-8b-8192-tool-use-preview",
"object": "model",
"created": 1693721698,
"owned_by": "Groq",
"active": true,
"context_window": 8192,
"public_apps": null
}
]
}
75 changes: 75 additions & 0 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::*;
use crate::api::Instance;
use httpmock::prelude::*;

#[test]
fn test_api_instance() {
let api_instance = api::Instance::new();
assert!(api_instance.is_ok());
}

#[test]
fn test_read_words_from_file() {
let to_link = read_words_from_file(PathBuf::from("examples/link.txt"));
assert!(to_link.is_ok());
let to_avoid = read_words_from_file(PathBuf::from("examples/avoid.txt"));
assert!(to_avoid.is_ok());
}

#[tokio::test]
async fn test_fetch_language_models() {
// Start a lightweight mock server.
let server = MockServer::start_async().await;

// Create a mock on the server.
let mock = server.mock(|when, then| {
when.method(GET).path("/models");
then.status(200)
.header("content-type", "application/json")
.body_from_file("src/tests/mock_responses/language_models.json");
});

// Create an API instance and set the base url to mock server url
let mut api_instance = Instance::new().unwrap();
api_instance.set_base_url(server.url("/"));

// Get response from mock server
let response = api_instance.fetch_language_model_ids().await.unwrap();
mock.assert();

// Compare outputs
let output = response.join("\n");
let expected_output = fs::read_to_string("src/tests/expected_outputs/language_models.txt").unwrap();
assert_eq!(output, expected_output);
}

#[tokio::test]
async fn test_fetch_clue_collection() {
// Start a lightweight mock server.
let server = MockServer::start_async().await;

// Create a mock on the server.
let mock = server.mock(|when, then| {
when.method(POST).path("/chat/completions");
then.status(200)
.header("content-type", "application/json")
.body_from_file("src/tests/mock_responses/chat_completions.json");
});

// Create an API instance and set the base url to mock server url
let mut api_instance = Instance::new().unwrap();
api_instance.set_base_url(server.url("/"));

// Get response from mock server
let response = api_instance
.fetch_clue_collection(vec![], vec![])
.await
.unwrap();
mock.assert();

// Compare outputs
let output = response.generate_raw_list();
let expected_output =
fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap();
assert_eq!(output, expected_output);
}