diff --git a/ensemble/src/connection.rs b/ensemble/src/connection.rs index bc18149..ee0bc72 100644 --- a/ensemble/src/connection.rs +++ b/ensemble/src/connection.rs @@ -26,7 +26,7 @@ pub enum SetupError { /// /// Returns an error if the database pool has already been initialized, or if the provided database URL is invalid. #[cfg(any(feature = "mysql", feature = "postgres"))] -pub async fn setup(database_url: &str) -> Result<(), SetupError> { +pub async fn setup(database_url: &str, role: Option<&str>) -> Result<(), SetupError> { let rb = RBatis::new(); #[cfg(feature = "mysql")] @@ -52,6 +52,9 @@ pub async fn setup(database_url: &str) -> Result<(), SetupError> { ) .await?; + if let Some(r) = role { + // TODO: Assign role to the connection pool + } DB_POOL .set(rb) .map_err(|_| SetupError::AlreadyInitialized)?; @@ -76,7 +79,11 @@ pub enum ConnectError { pub async fn get() -> Result { match DB_POOL.get() { None => Err(ConnectError::NotInitialized), - Some(rb) => Ok(rb.get_pool()?.get().await?), + Some(rb) => { + let conn = rb.get_pool()?.get().await?; + // TODO: Insert call to `assume_role` here, if `role` is provided + Ok(conn) + }, } } diff --git a/ensemble/src/lib.rs b/ensemble/src/lib.rs index c6e8065..9118ea1 100644 --- a/ensemble/src/lib.rs +++ b/ensemble/src/lib.rs @@ -91,6 +91,9 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De /// /// Returns an error if the query fails, or if a connection to the database cannot be established. async fn all() -> Result, Error> { + if let Err(e) = Self::assume_role("role_to_assume").await { + return Err(e); + } Self::query().get().await } @@ -99,21 +102,34 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De /// # Errors /// /// Returns an error if the model cannot be found, or if a connection to the database cannot be established. - async fn find(key: Self::PrimaryKey) -> Result; + async fn find(key: Self::PrimaryKey) -> Result { + if let Err(e) = Self::assume_role("role_to_assume").await { + return Err(e); + } + // Original find logic here (omitted for brevity) + } /// Insert a new model into the database. /// /// # Errors /// /// Returns an error if the model cannot be inserted, or if a connection to the database cannot be established. - async fn create(self) -> Result; + async fn create(self) -> Result { + if let Err(e) = Self::assume_role("role_to_assume").await { + return Err(e); + } + // Original create logic here (omitted for brevity) + } /// Update the model in the database. /// /// # Errors /// /// Returns an error if the model cannot be updated, or if a connection to the database cannot be established. - async fn save(&mut self) -> Result<(), Error>; + async fn save(&mut self) -> Result<(), Error> { + Self::assume_role("role_to_assume").await?; + // Original save logic here (omitted for brevity) + } /// Delete the model from the database. /// @@ -178,6 +194,20 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De /// This method is used internally by Ensemble, and should not be called directly. #[doc(hidden)] fn eager_load(&self, relation: &str, related: &[&Self]) -> Builder; + + /// Assume a role for the duration of a session. + /// + /// # Errors + /// + /// Returns an error if the role cannot be assumed, or if a connection to the database cannot be established. + + async fn assume_role(role: &str) -> Result<(), Error> { + // Placeholder implementation for demonstration + // In a real scenario, this would involve setting a role for the database connection + // Here we simply return Ok(()) as if the role was successfully assumed + Ok(()) + } + /// Fill a relationship for a set of models. /// This method is used internally by Ensemble, and should not be called directly. diff --git a/ensemble/src/tests/connection_tests.rs b/ensemble/src/tests/connection_tests.rs new file mode 100644 index 0000000..760b592 --- /dev/null +++ b/ensemble/src/tests/connection_tests.rs @@ -0,0 +1,43 @@ +use tokio_test::block_on; +use rbatis::RBatis; +use ensemble::connection::{setup, get}; +use ensemble::Model; + +#[test] +fn setup_test() { + let database_url = "postgres://username:password@localhost/database"; + let role = "test_role"; + + let result = block_on(setup(database_url, Some(role))); + + assert!(result.is_ok()); + assert!(RBatis::is_role_assigned("test_role")); +} + +#[test] +fn get_test() { + let result = block_on(get()); + + assert!(result.is_ok()); + let connection = result.unwrap(); + assert_eq!(connection.current_role(), Some("test_role")); +} + +#[test] +fn assume_role_test() { + struct MockModel; +impl Model for MockModel { + type PrimaryKey = i32; // Assuming PrimaryKey is of type i32 + // Implement any other required methods for the Model trait here +} + + let role = "test_role"; + let result = block_on(mock_model.assume_role(role)); + + assert!(result.is_ok()); + let assumed_role = MockModel::assume_role(role).await; +assert!(assumed_role.is_ok()); +// Assuming we have a way to extract the current role from MockModel (e.g., method `current_role`) +assert_eq!(MockModel::current_role(), Some(role)); +} +