diff --git a/async-nats/src/service/endpoint.rs b/async-nats/src/service/endpoint.rs index bee4e1fe5..477154ecd 100644 --- a/async-nats/src/service/endpoint.rs +++ b/async-nats/src/service/endpoint.rs @@ -183,3 +183,15 @@ pub struct Stats { /// Queue group to which this endpoint is assigned to. pub queue_group: String, } + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq)] +pub struct Info { + /// Name of the endpoint. + pub name: String, + /// Endpoint subject. + pub subject: String, + /// Queue group to which this endpoint is assigned. + pub queue_group: String, + /// Endpoint-specific metadata. + pub metadata: HashMap, +} diff --git a/async-nats/src/service/mod.rs b/async-nats/src/service/mod.rs index 75dd4ef02..fe926fc85 100644 --- a/async-nats/src/service/mod.rs +++ b/async-nats/src/service/mod.rs @@ -93,10 +93,10 @@ pub struct Info { pub description: Option, /// Service version. pub version: String, - /// All service endpoints. - pub subjects: Vec, /// Additional metadata pub metadata: HashMap, + /// Info about all service endpoints. + pub endpoints: Vec, } /// Configuration of the [Service]. @@ -322,6 +322,10 @@ impl Service { "service name is not a valid string (only A-Z, a-z, 0-9, _, - are allowed)", ))); } + let endpoints_state = Arc::new(Mutex::new(Endpoints { + endpoints: HashMap::new(), + })); + let queue_group = config .queue_group .unwrap_or(DEFAULT_QUEUE_GROUP.to_string()); @@ -334,15 +338,12 @@ impl Service { id: id.clone(), description: config.description.clone(), version: config.version.clone(), - subjects: Vec::default(), metadata: config.metadata.clone().unwrap_or_default(), + endpoints: Vec::new(), }; let (shutdown_tx, _) = tokio::sync::broadcast::channel(1); - let endpoints = HashMap::new(); - let endpoints_state = Arc::new(Mutex::new(Endpoints { endpoints })); - // create subscriptions for all verbs. let mut pings = verb_subscription(client.clone(), Verb::Ping, config.name.clone(), id.clone()).await?; @@ -355,7 +356,6 @@ impl Service { let handle = tokio::task::spawn({ let mut stats_callback = config.stats_handler; let info = info.clone(); - let subjects = subjects.clone(); let endpoints_state = endpoints_state.clone(); let client = client.clone(); async move { @@ -371,10 +371,20 @@ impl Service { client.publish(ping.reply.unwrap(), pong.into()).await?; }, Some(info_request) = infos.next() => { - let subjects = subjects.clone(); let info = info.clone(); + + let endpoints: Vec = { + endpoints_state.lock().unwrap().endpoints.values().map(|value| { + endpoint::Info { + name: value.name.to_owned(), + subject: value.subject.to_owned(), + queue_group: value.queue_group.to_owned(), + metadata: value.metadata.to_owned() + } + }).collect() + }; let info = Info { - subjects: subjects.lock().unwrap().to_vec(), + endpoints, ..info }; let info_json = serde_json::to_vec(&info).map(Bytes::from)?; diff --git a/async-nats/tests/service_tests.rs b/async-nats/tests/service_tests.rs index 2aa2597d2..b4252ca9d 100644 --- a/async-nats/tests/service_tests.rs +++ b/async-nats/tests/service_tests.rs @@ -469,6 +469,45 @@ mod service { assert_eq!(responses.take(2).count().await, 2); } + #[tokio::test] + async fn info() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let service = client + .service_builder() + .start("service", "1.0.0") + .await + .unwrap(); + + let endpoint_info = service::endpoint::Info { + name: "endpoint_1".to_string(), + subject: "subject".to_string(), + queue_group: "queue".to_string(), + metadata: HashMap::from([("key".to_string(), "value".to_string())]), + }; + + service + .endpoint_builder() + .name(&endpoint_info.name) + .metadata(endpoint_info.metadata.clone()) + .queue_group(&endpoint_info.queue_group) + .add(&endpoint_info.subject) + .await + .unwrap(); + + let info: service::Info = serde_json::from_slice( + &client + .request("$SRV.INFO".into(), "".into()) + .await + .unwrap() + .payload, + ) + .unwrap(); + + assert_eq!(&endpoint_info, info.endpoints.first().unwrap()); + } + #[tokio::test] #[cfg(not(target_os = "windows"))] async fn cross_clients_tests() {