From 47da3c8e07612d193cf66f2648a9873b4b9b38f3 Mon Sep 17 00:00:00 2001 From: Shubham Patel Date: Thu, 16 Mar 2023 19:32:17 +0530 Subject: [PATCH] second update #199 --- watchtower-plugin/src/constants.rs | 1 + watchtower-plugin/src/convert.rs | 91 ++++++++++++++++++++---------- watchtower-plugin/src/dbm.rs | 59 ++++++++----------- 3 files changed, 86 insertions(+), 65 deletions(-) diff --git a/watchtower-plugin/src/constants.rs b/watchtower-plugin/src/constants.rs index 22ae22af..1b14803d 100644 --- a/watchtower-plugin/src/constants.rs +++ b/watchtower-plugin/src/constants.rs @@ -3,6 +3,7 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR"; pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower"; /// Collections of plugin option names, default values and descriptions + pub const WT_PORT: &str = "watchtower-port"; pub const DEFAULT_WT_PORT: i64 = 9814; pub const WT_PORT_DESC: &str = "tower API port"; diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs index 4cf1031f..96bdfe77 100644 --- a/watchtower-plugin/src/convert.rs +++ b/watchtower-plugin/src/convert.rs @@ -199,8 +199,8 @@ impl TryFrom for GetAppointmentParams { let param_count = a.len(); if param_count != 2 { Err(GetAppointmentError::InvalidFormat(format!( - "Unexpected request format. The request needs 2 parameter. Received: {param_count}" - ))) + "Unexpected request format. The request needs 2 parameter. Received: {param_count}" + ))) } else { let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { TowerId::from_str(s).map_err(|_| { @@ -288,43 +288,75 @@ impl TryFrom for GetRegistrationReceiptParams { fn try_from(value: serde_json::Value) -> Result { match value { serde_json::Value::Array(a) => { - let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { - TowerId::from_str(s).map_err(|_| { + let param_count = a.len(); + if param_count != 1 && param_count != 3 { + Err(GetRegistrationReceiptError::InvalidFormat(format!( + "Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}" + ))) + } else{ + let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { + TowerId::from_str(s).map_err(|_| { GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned()) - }) - } else { - Err(GetRegistrationReceiptError::InvalidId( - "tower_id must be a hex encoded string".to_owned(), - )) - }?; - let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) { - if start >= 0 { - Some(start as u32) + }) } else { - return Err(GetRegistrationReceiptError::InvalidFormat( + Err(GetRegistrationReceiptError::InvalidId( + "tower_id must be a hex encoded string".to_owned(), + )) + }?; + let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) { + if start >= 0 { + Some(start as u32) + } else { + return Err(GetRegistrationReceiptError::InvalidFormat( "Subscription-start must be a positive integer".to_owned(), - )); - } - } else { - None - }; - let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) { - if expire >= 0 { + )); + } + } else { + None + }; + let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) { + if expire > subscription_start.unwrap() as i64 { Some(expire as u32) + } else { + return Err(GetRegistrationReceiptError::InvalidFormat( + "Subscription-expire must be a positive integer and greater than subscription_start".to_owned(), + )); + } } else { + None + }; + if subscription_start.is_some() != subscription_expiry.is_some() { return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription-expire must be a positive integer".to_owned(), - )); + "Subscription-start and subscription-expiry must be provided together".to_owned(), + )); } - } else { - None - }; - Ok(Self { + Ok(Self { tower_id, subscription_start, subscription_expiry, - }) - } + }) + } + }, + serde_json::Value::Object(mut m) => { + let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"]; + let param_count = m.len(); + + if m.is_empty() || param_count > allowed_keys.len() { + Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}"))) + } else if !m.contains_key(allowed_keys[0]){ + Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0]))) + } else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) { + Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned())) + } else { + let mut params = Vec::with_capacity(allowed_keys.len()); + for k in allowed_keys { + if let Some(v) = m.remove(k) { + params.push(v); + } + } + GetRegistrationReceiptParams::try_from(json!(params)) + } + }, _ => Err(GetRegistrationReceiptError::InvalidFormat(format!( "Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'" ))), @@ -332,7 +364,6 @@ impl TryFrom for GetRegistrationReceiptParams { } } - /// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook. #[derive(Debug, Serialize, Deserialize)] pub struct CommitmentRevocation { diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index a3d04ba1..e58c9585 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use std::path::PathBuf; use std::str::FromStr; -use rusqlite::{params, Connection, Error as SqliteError, ToSql}; +use rusqlite::{params, Connection, Error as SqliteError}; use bitcoin::secp256k1::SecretKey; @@ -219,38 +219,27 @@ impl DBM { subscription_start: Option, subscription_expiry: Option, ) -> Option { - let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string(); - - let mut params = vec![tower_id.to_vec()]; - - if let Some(start) = subscription_start { - query.push_str(" AND subscription_start >= ?2"); - params.push(start.to_be_bytes().to_vec()); - } else { - query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)"); - } - - if let Some(expiry) = subscription_expiry { - query.push_str(" AND subscription_expiry <= ?3"); - params.push(expiry.to_be_bytes().to_vec()); - } - - //query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1"); + let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1 AND (subscription_start >=?2 OR ?2 is NULL) AND (subscription_expiry <=?3 OR ?3 is NULL)".to_string(); + if subscription_expiry == None { + query.push_str(" OR subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)") + }; let mut stmt = self.connection.prepare(&query).unwrap(); - let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect(); - stmt.query_row(params.as_slice(), |row| { - let slots: u32 = row.get(0).unwrap(); - let start: u32 = row.get(1).unwrap(); - let expiry: u32 = row.get(2).unwrap(); - let signature: String = row.get(3).unwrap(); + stmt.query_row( + params![tower_id.to_vec(), subscription_start, subscription_expiry], + |row| { + let slots: u32 = row.get(0).unwrap(); + let start: u32 = row.get(1).unwrap(); + let expiry: u32 = row.get(2).unwrap(); + let signature: String = row.get(3).unwrap(); - Ok(RegistrationReceipt::with_signature( - user_id, slots, start, expiry, signature, - )) - }). - ok() + Ok(RegistrationReceipt::with_signature( + user_id, slots, start, expiry, signature, + )) + }, + ) + .ok() } /// Removes a tower record from the database. @@ -661,8 +650,8 @@ mod tests { use teos_common::cryptography::get_random_keypair; use teos_common::test_utils::{ - generate_random_appointment, get_random_registration_receipt, get_random_user_id, - get_registration_receipt_from_previous, + generate_random_appointment, get_random_int, get_random_registration_receipt, + get_random_user_id, get_registration_receipt_from_previous, }; impl DBM { @@ -737,8 +726,8 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); - let subscription_start = None; - let subscription_expiry = None; + let subscription_start = Some(receipt.subscription_start()); + let subscription_expiry = Some(receipt.subscription_expiry()); // Check the receipt was stored dbm.store_tower_record(tower_id, net_addr, &receipt) @@ -794,8 +783,8 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); - let subscription_start = None; - let subscription_expiry = None; + let subscription_start = get_random_int(); + let subscription_expiry = get_random_int(); // Store it once dbm.store_tower_record(tower_id, net_addr, &receipt)