Skip to content

Commit

Permalink
Merge pull request #20 from theoforger/feature/implement-traits
Browse files Browse the repository at this point in the history
Implement Display/Debug traits
  • Loading branch information
theoforger authored Oct 2, 2024
2 parents 5d00b04 + ccb46e7 commit 726efb5
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 32 deletions.
6 changes: 1 addition & 5 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::env;

pub struct Instance {
client: reqwest::Client,
base_url: String,
pub base_url: String,
key: String,
}

Expand All @@ -33,8 +33,4 @@ impl Instance {
env::var(var_name)
.map_err(|_| format!("Cannot read environment variable: {}", var_name).into())
}

pub fn set_base_url(&mut self, base_url: String) {
self.base_url = base_url;
}
}
28 changes: 17 additions & 11 deletions src/clue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use crate::json::chat_completions::{ChatCompletionsResponse, Usage};
use comfy_table::modifiers::UTF8_ROUND_CORNERS;
use comfy_table::presets::UTF8_FULL;
use comfy_table::{Attribute, Cell, CellAlignment, ContentArrangement, Table};
use std::fmt::{Debug, Display};

struct Clue {
clue_word: String,
count: usize,
Expand Down Expand Up @@ -93,7 +95,7 @@ impl ClueCollection {
self.clues.is_empty()
}

pub fn generate_list(&self) -> String {
fn generate_list(&self) -> String {
let mut list = String::new();
for clue in &self.clues {
let clue_string = format!(
Expand All @@ -107,7 +109,7 @@ impl ClueCollection {
list
}

pub fn generate_table(&self) -> String {
fn generate_table(&self) -> Table {
let mut table = Table::new();

// Set up header and styles
Expand Down Expand Up @@ -146,15 +148,7 @@ impl ClueCollection {
.expect("The table should have more than 2 columns");
second_column.set_cell_alignment(CellAlignment::Center);

table.to_string()
}

pub fn display_list(&self) {
println!("{}", self.generate_list());
}

pub fn display_table(&self) {
println!("{}", self.generate_table());
table
}

pub fn display_token_info(&self) {
Expand All @@ -169,3 +163,15 @@ impl ClueCollection {
);
}
}

impl Display for ClueCollection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.generate_table())
}
}

impl Debug for ClueCollection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.generate_list())
}
}
2 changes: 1 addition & 1 deletion src/json/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pub struct Model {
#[derive(Deserialize)]
pub struct ModelsResponse {
pub data: Vec<Model>,
}
}
6 changes: 3 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box<dyn Error>> {

// If -g is set, display models and exit the program
if args.get {
model_collection.display_list();
println!("{model_collection}");
return Ok(());
}

Expand Down Expand Up @@ -91,9 +91,9 @@ fn handle_output(args: &Args, clue_collection: ClueCollection) -> Result<(), Box
println!("The language model didn't return any useful clues. Maybe try again?");
} else if let Some(output_path) = &args.output {
println!("Writing to file '{}'...", output_path.display());
write_content_to_file(output_path, clue_collection.generate_table())?;
write_content_to_file(output_path, clue_collection.to_string())?;
} else {
clue_collection.display_table();
println!("{clue_collection}");
}

// If -t is set, output token usage information
Expand Down
15 changes: 7 additions & 8 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::json::models::ModelsResponse;
use dialoguer::MultiSelect;
use std::fmt::{Display, Formatter};

pub struct ModelCollection {
model_ids: Vec<String>,
Expand Down Expand Up @@ -33,14 +34,6 @@ impl ModelCollection {
chosen_model_ids
}

pub fn generate_list(&self) -> String {
self.model_ids.join("\n")
}

pub fn display_list(&self) {
println!("{}", self.generate_list());
}

pub fn validate_model_id(&self, model_id: &String) -> Result<(), Box<dyn std::error::Error>> {
if !self.model_ids.contains(model_id) {
return Err(format!(
Expand All @@ -53,3 +46,9 @@ impl ModelCollection {
Ok(())
}
}

impl Display for ModelCollection {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.model_ids.join("\n"))
}
}
8 changes: 4 additions & 4 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ async fn test_get_models() {

// Create an API instance and set the base url to mock server url
let mut api_instance = Instance::new().unwrap();
api_instance.set_base_url(server.url("/"));
api_instance.base_url = server.url("/");

// Get response from mock server
let response = ModelCollection::new(api_instance.get_models().await.unwrap());
mock.assert();

// Compare outputs
let output = response.generate_list();
let output = response.to_string();
let expected_output = fs::read_to_string("src/tests/expected_outputs/models.txt").unwrap();
assert_eq!(output, expected_output);
}
Expand All @@ -60,7 +60,7 @@ async fn test_post_chat_completions() {

// Create an API instance and set the base url to mock server url
let mut api_instance = Instance::new().unwrap();
api_instance.set_base_url(server.url("/"));
api_instance.base_url = server.url("/");

// Get responses from mock server
let responses = vec![api_instance
Expand All @@ -70,7 +70,7 @@ async fn test_post_chat_completions() {
mock.assert();

// Compare outputs
let output = ClueCollection::new(responses).generate_list();
let output = format!("{:?}", ClueCollection::new(responses));
let expected_output =
fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap();
assert_eq!(output, expected_output);
Expand Down

0 comments on commit 726efb5

Please sign in to comment.