From b807150bcb891f86634f6e7820cf8257a5aa3690 Mon Sep 17 00:00:00 2001 From: Shubham Patel Date: Sat, 4 Mar 2023 11:04:25 +0530 Subject: [PATCH] Add optional boundaries to getregistrationreceipt --- watchtower-plugin/src/convert.rs | 111 +++++++++++++++++++++++++++++ watchtower-plugin/src/dbm.rs | 91 +++++++++++++++-------- watchtower-plugin/src/main.rs | 31 +++++--- watchtower-plugin/src/wt_client.rs | 17 ++++- 4 files changed, 207 insertions(+), 43 deletions(-) diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs index ec1b9d01..a2acd576 100644 --- a/watchtower-plugin/src/convert.rs +++ b/watchtower-plugin/src/convert.rs @@ -258,6 +258,117 @@ impl TryFrom for GetAppointmentParams { } } +// Errors related to `getregistrationreceipt` command +#[derive(Debug)] +pub enum GetRegistrationReceiptError { + InvalidId(String), + InvalidFormat(String), +} + +impl std::fmt::Display for GetRegistrationReceiptError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"), + GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"), + } + } +} + +// Parameters related to the `getregistrationreceipt` command +#[derive(Debug)] +pub struct GetRegistrationReceiptParams { + pub tower_id: TowerId, + pub subscription_start: Option, + pub subscription_expiry: Option, +} + +impl TryFrom for GetRegistrationReceiptParams { + type Error = GetRegistrationReceiptError; + + fn try_from(value: serde_json::Value) -> Result { + match value { + serde_json::Value::Array(a) => { + let param_count = a.len(); + if param_count == 2 { + Err(GetRegistrationReceiptError::InvalidFormat(( + "Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string() + )) + } else if param_count != 1 && param_count != 3 { + Err(GetRegistrationReceiptError::InvalidFormat(format!( + "Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}" + ))) + } else { + let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { + TowerId::from_str(s).map_err(|_| { + GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned()) + }) + } else { + Err(GetRegistrationReceiptError::InvalidId( + "tower_id must be a hex encoded string".to_owned(), + )) + }?; + + let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){ + let start = start.as_i64().ok_or_else(|| { + GetRegistrationReceiptError::InvalidFormat( + "Subscription_start must be a positive integer".to_owned(), + ) + })?; + + let expire = expire.as_i64().ok_or_else(|| { + GetRegistrationReceiptError::InvalidFormat( + "Subscription_expire must be a positive integer".to_owned(), + ) + })?; + + if start >= 0 && expire > start { + (Some(start as u32), Some(expire as u32)) + } else { + return Err(GetRegistrationReceiptError::InvalidFormat( + "subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(), + )); + } + } else { + (None, None) + }; + + Ok( + Self { + tower_id, + subscription_start, + subscription_expiry, + } + ) + } + }, + serde_json::Value::Object(mut m) => { + let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"]; + let param_count = m.len(); + + if m.is_empty() || param_count > allowed_keys.len() { + Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}"))) + } else if !m.contains_key(allowed_keys[0]){ + Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0]))) + } else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) { + Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned())) + } else { + let mut params = Vec::with_capacity(allowed_keys.len()); + for k in allowed_keys { + if let Some(v) = m.remove(k) { + params.push(v); + } + } + + GetRegistrationReceiptParams::try_from(json!(params)) + } + }, + _ => Err(GetRegistrationReceiptError::InvalidFormat(format!( + "Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'" + ))), + } + } +} + /// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook. #[derive(Debug, Serialize, Deserialize)] pub struct CommitmentRevocation { diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index bf271bb8..b607e37a 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use std::path::PathBuf; use std::str::FromStr; -use rusqlite::{params, Connection, Error as SqliteError}; +use rusqlite::{params, Connection, Error as SqliteError, ToSql}; use bitcoin::secp256k1::SecretKey; @@ -209,36 +209,43 @@ impl DBM { Some(tower) } - /// Loads the latest registration receipt for a given tower. - /// + /// Loads the registration receipt(s) for a given tower in the given subscription range. + /// If no range is given, then loads the latest receipt /// Latests is determined by the one with the `subscription_expiry` further into the future. pub fn load_registration_receipt( &self, tower_id: TowerId, user_id: UserId, - ) -> Option { - let mut stmt = self - .connection - .prepare( - "SELECT available_slots, subscription_start, subscription_expiry, signature - FROM registration_receipts - WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry) - FROM registration_receipts - WHERE tower_id = ?1)", - ) - .unwrap(); + subscription_start: Option, + subscription_expiry: Option, + ) -> Option> { + let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string(); + + let tower_id_encoded = tower_id.to_vec(); + let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded]; + + if subscription_expiry.is_none() { + query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)") + } else { + query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3"); + params.push(&subscription_start); + params.push(&subscription_expiry) + } + let mut stmt = self.connection.prepare(&query).unwrap(); - stmt.query_row([tower_id.to_vec()], |row| { - let slots: u32 = row.get(0).unwrap(); - let start: u32 = row.get(1).unwrap(); - let expiry: u32 = row.get(2).unwrap(); - let signature: String = row.get(3).unwrap(); + stmt.query_map(params.as_slice(), |row| { + let slots: u32 = row.get(0)?; + let start: u32 = row.get(1)?; + let expiry: u32 = row.get(2)?; + let signature: String = row.get(3)?; Ok(RegistrationReceipt::with_signature( user_id, slots, start, expiry, signature, )) }) - .ok() + .unwrap() + .map(|r| r.ok()) + .collect() } /// Removes a tower record from the database. @@ -725,34 +732,49 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); + let subscription_start = Some(receipt.subscription_start()); + let subscription_expiry = Some(receipt.subscription_expiry()); // Check the receipt was stored dbm.store_tower_record(tower_id, net_addr, &receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap()[0], receipt ); - // Add another receipt for the same tower with a higher expiry and check this last one is loaded + // Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts let middle_receipt = get_registration_receipt_from_previous(&receipt); let latest_receipt = get_registration_receipt_from_previous(&middle_receipt); + let latest_subscription_expiry = Some(latest_receipt.subscription_expiry()); + dbm.store_tower_record(tower_id, net_addr, &latest_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, latest_receipt.user_id()) - .unwrap(), - latest_receipt + dbm.load_registration_receipt( + tower_id, + latest_receipt.user_id(), + subscription_start, + latest_subscription_expiry + ) + .unwrap(), + vec![receipt, latest_receipt.clone()] ); - // Add a final one with a lower expiry and check the last is still loaded + // Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry + // params are not passed dbm.store_tower_record(tower_id, net_addr, &middle_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, latest_receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None) + .unwrap()[0], latest_receipt ); } @@ -765,13 +787,20 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); + let subscription_start = Some(receipt.subscription_start()); + let subscription_expiry = Some(receipt.subscription_expiry()); // Store it once dbm.store_tower_record(tower_id, net_addr, &receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap()[0], receipt ); diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index 5778e8d0..3a3594c3 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -18,7 +18,9 @@ use teos_common::protos as common_msgs; use teos_common::TowerId; use teos_common::{cryptography, errors}; -use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams}; +use watchtower_plugin::convert::{ + CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams, +}; use watchtower_plugin::net::http::{ self, get_request, post_request, process_post_response, AddAppointmentError, ApiResponse, RequestError, @@ -127,22 +129,33 @@ async fn register( Ok(json!(receipt)) } -/// Gets the latest registration receipt from the client to a given tower (if it exists). -/// +/// Gets the registration receipt(s) from the client to a given tower (if it exists) in the given +/// range. If no range is given, then gets the latest registration receipt. /// This is pulled from the database async fn get_registration_receipt( plugin: Plugin>>, v: serde_json::Value, ) -> Result { - let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?; + let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?; + let tower_id = params.tower_id; + let subscription_start = params.subscription_start; + let subscription_expiry = params.subscription_expiry; let state = plugin.state().lock().unwrap(); - if let Some(response) = state.get_registration_receipt(tower_id) { - Ok(json!(response)) + let response = + state.get_registration_receipt(tower_id, subscription_start, subscription_expiry); + if response.clone().unwrap().is_empty() { + if state.towers.contains_key(&tower_id) { + Err(anyhow!( + "No registration receipt found for {tower_id} on the given range" + )) + } else { + Err(anyhow!( + "Cannot find {tower_id} within the known towers. Have you registered?" + )) + } } else { - Err(anyhow!( - "Cannot find {tower_id} within the known towers. Have you registered?" - )) + Ok(json!(response)) } } diff --git a/watchtower-plugin/src/wt_client.rs b/watchtower-plugin/src/wt_client.rs index ecf26db0..d3223242 100644 --- a/watchtower-plugin/src/wt_client.rs +++ b/watchtower-plugin/src/wt_client.rs @@ -179,9 +179,20 @@ impl WTClient { Ok(()) } - /// Gets the latest registration receipt of a given tower. - pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option { - self.dbm.load_registration_receipt(tower_id, self.user_id) + /// Gets the registration receipt(s) of a given tower in the given range. + /// If no range is given then gets the latest registration receipt + pub fn get_registration_receipt( + &self, + tower_id: TowerId, + subscription_start: Option, + subscription_expiry: Option, + ) -> Option> { + self.dbm.load_registration_receipt( + tower_id, + self.user_id, + subscription_start, + subscription_expiry, + ) } /// Loads a tower record from the database.