Skip to content

Commit

Permalink
Add optional boundaries to getregistrationreceipt
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Oct 10, 2023
1 parent a4acced commit b807150
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 43 deletions.
111 changes: 111 additions & 0 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,117 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
}
}

// Errors related to `getregistrationreceipt` command
#[derive(Debug)]
pub enum GetRegistrationReceiptError {
InvalidId(String),
InvalidFormat(String),
}

impl std::fmt::Display for GetRegistrationReceiptError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"),
GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"),
}
}
}

// Parameters related to the `getregistrationreceipt` command
#[derive(Debug)]
pub struct GetRegistrationReceiptParams {
pub tower_id: TowerId,
pub subscription_start: Option<u32>,
pub subscription_expiry: Option<u32>,
}

impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
type Error = GetRegistrationReceiptError;

fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
match value {
serde_json::Value::Array(a) => {
let param_count = a.len();
if param_count == 2 {
Err(GetRegistrationReceiptError::InvalidFormat((
"Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()
))
} else if param_count != 1 && param_count != 3 {
Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
)))
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
})
} else {
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
))
}?;

let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){
let start = start.as_i64().ok_or_else(|| {
GetRegistrationReceiptError::InvalidFormat(
"Subscription_start must be a positive integer".to_owned(),
)
})?;

let expire = expire.as_i64().ok_or_else(|| {
GetRegistrationReceiptError::InvalidFormat(
"Subscription_expire must be a positive integer".to_owned(),
)
})?;

if start >= 0 && expire > start {
(Some(start as u32), Some(expire as u32))
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
));
}
} else {
(None, None)
};

Ok(
Self {
tower_id,
subscription_start,
subscription_expiry,
}
)
}
},
serde_json::Value::Object(mut m) => {
let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"];
let param_count = m.len();

if m.is_empty() || param_count > allowed_keys.len() {
Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}")))
} else if !m.contains_key(allowed_keys[0]){
Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0])))
} else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) {
Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned()))
} else {
let mut params = Vec::with_capacity(allowed_keys.len());
for k in allowed_keys {
if let Some(v) = m.remove(k) {
params.push(v);
}
}

GetRegistrationReceiptParams::try_from(json!(params))
}
},
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'"
))),
}
}
}

/// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentRevocation {
Expand Down
91 changes: 60 additions & 31 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FromIterator;
use std::path::PathBuf;
use std::str::FromStr;

use rusqlite::{params, Connection, Error as SqliteError};
use rusqlite::{params, Connection, Error as SqliteError, ToSql};

use bitcoin::secp256k1::SecretKey;

Expand Down Expand Up @@ -209,36 +209,43 @@ impl DBM {
Some(tower)
}

/// Loads the latest registration receipt for a given tower.
///
/// Loads the registration receipt(s) for a given tower in the given subscription range.
/// If no range is given, then loads the latest receipt
/// Latests is determined by the one with the `subscription_expiry` further into the future.
pub fn load_registration_receipt(
&self,
tower_id: TowerId,
user_id: UserId,
) -> Option<RegistrationReceipt> {
let mut stmt = self
.connection
.prepare(
"SELECT available_slots, subscription_start, subscription_expiry, signature
FROM registration_receipts
WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry)
FROM registration_receipts
WHERE tower_id = ?1)",
)
.unwrap();
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<Vec<RegistrationReceipt>> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

let tower_id_encoded = tower_id.to_vec();
let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded];

if subscription_expiry.is_none() {
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
} else {
query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3");
params.push(&subscription_start);
params.push(&subscription_expiry)
}
let mut stmt = self.connection.prepare(&query).unwrap();

stmt.query_row([tower_id.to_vec()], |row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
let signature: String = row.get(3).unwrap();
stmt.query_map(params.as_slice(), |row| {
let slots: u32 = row.get(0)?;
let start: u32 = row.get(1)?;
let expiry: u32 = row.get(2)?;
let signature: String = row.get(3)?;

Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
})
.ok()
.unwrap()
.map(|r| r.ok())
.collect()
}

/// Removes a tower record from the database.
Expand Down Expand Up @@ -725,34 +732,49 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Check the receipt was stored
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap()[0],
receipt
);

// Add another receipt for the same tower with a higher expiry and check this last one is loaded
// Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts
let middle_receipt = get_registration_receipt_from_previous(&receipt);
let latest_receipt = get_registration_receipt_from_previous(&middle_receipt);

let latest_subscription_expiry = Some(latest_receipt.subscription_expiry());

dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
latest_receipt
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
latest_subscription_expiry
)
.unwrap(),
vec![receipt, latest_receipt.clone()]
);

// Add a final one with a lower expiry and check the last is still loaded
// Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry
// params are not passed
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)
.unwrap()[0],
latest_receipt
);
}
Expand All @@ -765,13 +787,20 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Store it once
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap()[0],
receipt
);

Expand Down
31 changes: 22 additions & 9 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use teos_common::protos as common_msgs;
use teos_common::TowerId;
use teos_common::{cryptography, errors};

use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams};
use watchtower_plugin::convert::{
CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams,
};
use watchtower_plugin::net::http::{
self, get_request, post_request, process_post_response, AddAppointmentError, ApiResponse,
RequestError,
Expand Down Expand Up @@ -127,22 +129,33 @@ async fn register(
Ok(json!(receipt))
}

/// Gets the latest registration receipt from the client to a given tower (if it exists).
///
/// Gets the registration receipt(s) from the client to a given tower (if it exists) in the given
/// range. If no range is given, then gets the latest registration receipt.
/// This is pulled from the database
async fn get_registration_receipt(
plugin: Plugin<Arc<Mutex<WTClient>>>,
v: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?;
let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?;
let tower_id = params.tower_id;
let subscription_start = params.subscription_start;
let subscription_expiry = params.subscription_expiry;
let state = plugin.state().lock().unwrap();

if let Some(response) = state.get_registration_receipt(tower_id) {
Ok(json!(response))
let response =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
if response.clone().unwrap().is_empty() {
if state.towers.contains_key(&tower_id) {
Err(anyhow!(
"No registration receipt found for {tower_id} on the given range"
))
} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered?"
))
}
} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered?"
))
Ok(json!(response))
}
}

Expand Down
17 changes: 14 additions & 3 deletions watchtower-plugin/src/wt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ impl WTClient {
Ok(())
}

/// Gets the latest registration receipt of a given tower.
pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option<RegistrationReceipt> {
self.dbm.load_registration_receipt(tower_id, self.user_id)
/// Gets the registration receipt(s) of a given tower in the given range.
/// If no range is given then gets the latest registration receipt
pub fn get_registration_receipt(
&self,
tower_id: TowerId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<Vec<RegistrationReceipt>> {
self.dbm.load_registration_receipt(
tower_id,
self.user_id,
subscription_start,
subscription_expiry,
)
}

/// Loads a tower record from the database.
Expand Down

0 comments on commit b807150

Please sign in to comment.