From 142e72ef66f30f9964fdb0509448a5ef2014c48d Mon Sep 17 00:00:00 2001 From: theoforger Date: Fri, 27 Sep 2024 15:54:29 -0400 Subject: [PATCH 1/4] Implement `Display` trait --- src/api/mod.rs | 6 +----- src/clue.rs | 27 ++++++++++++++++----------- src/main.rs | 6 +++--- src/model.rs | 15 +++++++-------- src/tests/mod.rs | 8 ++++---- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index a7cecb4..beb0d33 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -6,7 +6,7 @@ use std::env; pub struct Instance { client: reqwest::Client, - base_url: String, + pub base_url: String, key: String, } @@ -33,8 +33,4 @@ impl Instance { env::var(var_name) .map_err(|_| format!("Cannot read environment variable: {}", var_name).into()) } - - pub fn set_base_url(&mut self, base_url: String) { - self.base_url = base_url; - } } diff --git a/src/clue.rs b/src/clue.rs index 025d796..11069e0 100644 --- a/src/clue.rs +++ b/src/clue.rs @@ -2,6 +2,7 @@ use crate::json::chat_completions::{ChatCompletionsResponse, Usage}; use comfy_table::modifiers::UTF8_ROUND_CORNERS; use comfy_table::presets::UTF8_FULL; use comfy_table::{Attribute, Cell, CellAlignment, ContentArrangement, Table}; +use std::fmt::{Debug, Display}; struct Clue { clue_word: String, count: usize, @@ -93,7 +94,7 @@ impl ClueCollection { self.clues.is_empty() } - pub fn generate_list(&self) -> String { + fn generate_list(&self) -> String { let mut list = String::new(); for clue in &self.clues { let clue_string = format!( @@ -107,7 +108,7 @@ impl ClueCollection { list } - pub fn generate_table(&self) -> String { + fn generate_table(&self) -> Table { let mut table = Table::new(); // Set up header and styles @@ -146,15 +147,7 @@ impl ClueCollection { .expect("The table should have more than 2 columns"); second_column.set_cell_alignment(CellAlignment::Center); - table.to_string() - } - - pub fn display_list(&self) { - println!("{}", self.generate_list()); - } - - pub fn display_table(&self) { - println!("{}", self.generate_table()); + table } pub fn display_token_info(&self) { @@ -169,3 +162,15 @@ impl ClueCollection { ); } } + +impl Display for ClueCollection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.generate_table()) + } +} + +impl Debug for ClueCollection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.generate_list()) + } +} diff --git a/src/main.rs b/src/main.rs index c41d117..368f9be 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { // If -g is set, display models and exit the program if args.get { - model_collection.display_list(); + println!("{model_collection}"); return Ok(()); } @@ -91,9 +91,9 @@ fn handle_output(args: &Args, clue_collection: ClueCollection) -> Result<(), Box println!("The language model didn't return any useful clues. Maybe try again?"); } else if let Some(output_path) = &args.output { println!("Writing to file '{}'...", output_path.display()); - write_content_to_file(output_path, clue_collection.generate_table())?; + write_content_to_file(output_path, clue_collection.to_string())?; } else { - clue_collection.display_table(); + println!("{clue_collection}"); } // If -t is set, output token usage information diff --git a/src/model.rs b/src/model.rs index 9c9d551..7b2a658 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,6 @@ use crate::json::models::ModelsResponse; use dialoguer::MultiSelect; +use std::fmt::{Display, Formatter}; pub struct ModelCollection { model_ids: Vec, @@ -33,14 +34,6 @@ impl ModelCollection { chosen_model_ids } - pub fn generate_list(&self) -> String { - self.model_ids.join("\n") - } - - pub fn display_list(&self) { - println!("{}", self.generate_list()); - } - pub fn validate_model_id(&self, model_id: &String) -> Result<(), Box> { if !self.model_ids.contains(model_id) { return Err(format!( @@ -53,3 +46,9 @@ impl ModelCollection { Ok(()) } } + +impl Display for ModelCollection { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.model_ids.join("\n")) + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index aa0917b..06041c5 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -33,14 +33,14 @@ async fn test_get_models() { // Create an API instance and set the base url to mock server url let mut api_instance = Instance::new().unwrap(); - api_instance.set_base_url(server.url("/")); + api_instance.base_url = server.url("/"); // Get response from mock server let response = ModelCollection::new(api_instance.get_models().await.unwrap()); mock.assert(); // Compare outputs - let output = response.generate_list(); + let output = response.to_string(); let expected_output = fs::read_to_string("src/tests/expected_outputs/models.txt").unwrap(); assert_eq!(output, expected_output); } @@ -60,7 +60,7 @@ async fn test_post_chat_completions() { // Create an API instance and set the base url to mock server url let mut api_instance = Instance::new().unwrap(); - api_instance.set_base_url(server.url("/")); + api_instance.base_url = server.url("/"); // Get responses from mock server let responses = vec![api_instance @@ -70,7 +70,7 @@ async fn test_post_chat_completions() { mock.assert(); // Compare outputs - let output = ClueCollection::new(responses).generate_list(); + let output = ClueCollection::new(responses).to_string(); let expected_output = fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap(); assert_eq!(output, expected_output); From 06e41c8df2fd306b2c96ae4c2f35a8e01559c090 Mon Sep 17 00:00:00 2001 From: theoforger Date: Wed, 2 Oct 2024 08:54:27 -0400 Subject: [PATCH 2/4] Run cargo fmt --- src/json/models.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/json/models.rs b/src/json/models.rs index cee2d43..54f09a4 100644 --- a/src/json/models.rs +++ b/src/json/models.rs @@ -8,4 +8,4 @@ pub struct Model { #[derive(Deserialize)] pub struct ModelsResponse { pub data: Vec, -} \ No newline at end of file +} From e57b44aa52ede39e5d48428dca06337586adc87c Mon Sep 17 00:00:00 2001 From: theoforger Date: Wed, 2 Oct 2024 09:14:51 -0400 Subject: [PATCH 3/4] Adjust testing --- src/tests/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 06041c5..3f27de1 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -70,7 +70,7 @@ async fn test_post_chat_completions() { mock.assert(); // Compare outputs - let output = ClueCollection::new(responses).to_string(); + let output = format!("{:?}", ClueCollection::new(responses)); let expected_output = fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap(); assert_eq!(output, expected_output); From ccb46e73583924ab703c7934f33c247fd1ae1e02 Mon Sep 17 00:00:00 2001 From: theoforger Date: Wed, 2 Oct 2024 09:22:18 -0400 Subject: [PATCH 4/4] Add spacing --- src/clue.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/clue.rs b/src/clue.rs index 11069e0..f5c80cc 100644 --- a/src/clue.rs +++ b/src/clue.rs @@ -3,6 +3,7 @@ use comfy_table::modifiers::UTF8_ROUND_CORNERS; use comfy_table::presets::UTF8_FULL; use comfy_table::{Attribute, Cell, CellAlignment, ContentArrangement, Table}; use std::fmt::{Debug, Display}; + struct Clue { clue_word: String, count: usize,