Skip to content

Commit

Permalink
feat!: add context to checks
Browse files Browse the repository at this point in the history
BREAKING CHANGE

Signed-off-by: Gustavo Inacio <[email protected]>
  • Loading branch information
gusinacio committed Oct 30, 2024
1 parent f49c21d commit 58a6a52
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 49 deletions.
1 change: 1 addition & 0 deletions tap_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ anyhow.workspace = true
rand.workspace = true
thiserror = "1.0.38"
async-trait = "0.1.72"
anymap3 = "1.0.0"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_std"] }
Expand Down
6 changes: 3 additions & 3 deletions tap_core/src/manager/context/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ pub mod checks {
receipt::{
checks::{Check, CheckError, CheckResult, ReceiptCheck},
state::Checking,
ReceiptError, ReceiptWithState,
Context, ReceiptError, ReceiptWithState,
},
signed_message::MessageId,
};
Expand Down Expand Up @@ -296,7 +296,7 @@ pub mod checks {

#[async_trait::async_trait]
impl Check for AllocationIdCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let received_allocation_id = receipt.signed_receipt().message.allocation_id;
if self
.allocation_ids
Expand All @@ -323,7 +323,7 @@ pub mod checks {

#[async_trait::async_trait]
impl Check for SignatureCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
Expand Down
5 changes: 3 additions & 2 deletions tap_core/src/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
//! ReceiptWithState,
//! state::Checking,
//! checks::CheckList,
//! ReceiptError
//! ReceiptError,
//! Context
//! },
//! manager::{
//! Manager,
Expand Down Expand Up @@ -70,7 +71,7 @@
//! let receipt = EIP712SignedMessage::new(&domain_separator, message, &wallet).unwrap();
//!
//! let manager = Manager::new(domain_separator, MyContext, CheckList::empty());
//! manager.verify_and_store_receipt(receipt).await.unwrap()
//! manager.verify_and_store_receipt(&Context::new(), receipt).await.unwrap()
//! # }
//! ```
//!
Expand Down
11 changes: 7 additions & 4 deletions tap_core/src/manager/tap_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
receipt::{
checks::{CheckBatch, CheckList, TimestampCheck, UniqueCheck},
state::{Failed, Reserved},
ReceiptError, ReceiptWithState, SignedReceipt,
Context, ReceiptError, ReceiptWithState, SignedReceipt,
},
Error,
};
Expand Down Expand Up @@ -99,6 +99,7 @@ where
{
async fn collect_receipts(
&self,
ctx: &Context,
timestamp_buffer_ns: u64,
min_timestamp_ns: u64,
limit: Option<u64>,
Expand Down Expand Up @@ -140,7 +141,7 @@ where

for receipt in checking_receipts.into_iter() {
let receipt = receipt
.finalize_receipt_checks(&self.checks)
.finalize_receipt_checks(ctx, &self.checks)
.await
.map_err(|e| Error::ReceiptError(ReceiptError::RetryableCheck(e)))?;

Expand Down Expand Up @@ -184,6 +185,7 @@ where
///
pub async fn create_rav_request(
&self,
ctx: &Context,
timestamp_buffer_ns: u64,
receipts_limit: Option<u64>,
) -> Result<RAVRequest, Error> {
Expand All @@ -194,7 +196,7 @@ where
.unwrap_or(0);

let (valid_receipts, invalid_receipts) = self
.collect_receipts(timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
.collect_receipts(ctx, timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
.await?;

let expected_rav = Self::generate_expected_rav(&valid_receipts, previous_rav.clone());
Expand Down Expand Up @@ -271,12 +273,13 @@ where
///
pub async fn verify_and_store_receipt(
&self,
ctx: &Context,
signed_receipt: SignedReceipt,
) -> std::result::Result<(), Error> {
let mut received_receipt = ReceiptWithState::new(signed_receipt);

// perform checks
received_receipt.perform_checks(&self.checks).await?;
received_receipt.perform_checks(ctx, &self.checks).await?;

// store the receipt
self.context
Expand Down
2 changes: 1 addition & 1 deletion tap_core/src/rav.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//! 1. Create a [`RAVRequest`] with the valid receipts and the previous RAV.
//! 2. Send the request to the aggregator.
//! 3. The aggregator will verify the request and increment the total amount that
//! has been aggregated.
//! has been aggregated.
//! 4. The aggregator will return a [`SignedRAV`].
//! 5. Store the [`SignedRAV`].
//! 6. Repeat the process until the allocation is closed.
Expand Down
10 changes: 5 additions & 5 deletions tap_core/src/receipt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
//! # use std::sync::Arc;
//! use tap_core::{
//! receipt::checks::{Check, CheckResult, ReceiptCheck},
//! receipt::{ReceiptWithState, state::Checking}
//! receipt::{Context, ReceiptWithState, state::Checking}
//! };
//! # use async_trait::async_trait;
//!
//! struct MyCheck;
//!
//! #[async_trait]
//! impl Check for MyCheck {
//! async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
//! async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
//! // Implement your check here
//! Ok(())
//! }
Expand All @@ -33,7 +33,7 @@ use crate::signed_message::{SignatureBytes, SignatureBytesExt};

use super::{
state::{Checking, Failed},
ReceiptError, ReceiptWithState,
Context, ReceiptError, ReceiptWithState,
};
use std::{
collections::HashSet,
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Deref for CheckList {
/// Check trait is implemented by the lib user to validate receipts before they are stored.
#[async_trait::async_trait]
pub trait Check {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult;
async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult;
}

/// CheckBatch is mostly used by the lib to implement checks
Expand Down Expand Up @@ -119,7 +119,7 @@ impl StatefulTimestampCheck {

#[async_trait::async_trait]
impl Check for StatefulTimestampCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap();
let signed_receipt = receipt.signed_receipt();
if signed_receipt.message.timestamp_ns <= min_timestamp_ns {
Expand Down
2 changes: 2 additions & 0 deletions tap_core/src/receipt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@ pub type SignedReceipt = EIP712SignedMessage<Receipt>;

/// Result type for receipt
pub type ReceiptResult<T> = Result<T, ReceiptError>;

pub type Context = anymap3::Map<dyn std::any::Any + Send + Sync>;
24 changes: 14 additions & 10 deletions tap_core/src/receipt/received_receipt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
use alloy::dyn_abi::Eip712Domain;

use super::checks::CheckError;
use super::{Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use super::{Context, Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use crate::receipt::state::{AwaitingReserve, Checking, Failed, ReceiptState, Reserved};
use crate::{
manager::adapters::EscrowHandler, receipt::checks::ReceiptCheck,
Expand All @@ -28,16 +28,15 @@ pub type ResultReceipt<S> = std::result::Result<ReceiptWithState<S>, ReceiptWith
/// Typestate pattern for tracking the state of a receipt
///
/// - The [ `ReceiptState` ] trait represents the different states a receipt
/// can be in.
/// can be in.
/// - The [ `Checking` ] state is used to represent a receipt that is currently
/// being checked.
/// being checked.
/// - The [ `Failed` ] state is used to represent a receipt that has failed a
/// check or validation.
/// check or validation.
/// - The [ `AwaitingReserve` ] state is used to represent a receipt that has
/// passed all checks and is
/// awaiting escrow reservation.
/// passed all checks and is awaiting escrow reservation.
/// - The [ `Reserved` ] state is used to represent a receipt that has
/// successfully reserved escrow.
/// successfully reserved escrow.
#[derive(Debug, Clone)]
pub struct ReceiptWithState<S>
where
Expand Down Expand Up @@ -90,10 +89,14 @@ impl ReceiptWithState<Checking> {
/// cannot be comleted in the receipts current internal state.
/// All other checks must be complete before `CheckAndReserveEscrow`.
///
pub async fn perform_checks(&mut self, checks: &[ReceiptCheck]) -> ReceiptResult<()> {
pub async fn perform_checks(
&mut self,
ctx: &Context,
checks: &[ReceiptCheck],
) -> ReceiptResult<()> {
for check in checks {
// return early on an error
check.check(self).await.map_err(|e| match e {
check.check(ctx, self).await.map_err(|e| match e {
CheckError::Retryable(e) => ReceiptError::RetryableCheck(e.to_string()),
CheckError::Failed(e) => ReceiptError::CheckFailure(e.to_string()),
})?;
Expand All @@ -108,9 +111,10 @@ impl ReceiptWithState<Checking> {
///
pub async fn finalize_receipt_checks(
mut self,
ctx: &Context,
checks: &[ReceiptCheck],
) -> Result<ResultReceipt<AwaitingReserve>, String> {
let all_checks_passed = self.perform_checks(checks).await;
let all_checks_passed = self.perform_checks(ctx, checks).await;
if let Err(ReceiptError::RetryableCheck(e)) = all_checks_passed {
Err(e.to_string())
} else if let Err(e) = all_checks_passed {
Expand Down
41 changes: 24 additions & 17 deletions tap_core/tests/manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tap_core::{
receipt::{
checks::{Check, CheckError, CheckList, StatefulTimestampCheck},
state::Checking,
Receipt, ReceiptWithState,
Context, Receipt, ReceiptWithState,
},
signed_message::EIP712SignedMessage,
tap_eip712_domain,
Expand Down Expand Up @@ -145,7 +145,7 @@ async fn manager_verify_and_store_varying_initial_checks(
.insert(signer.address(), 999999);

assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
}
Expand Down Expand Up @@ -184,11 +184,11 @@ async fn manager_create_rav_request_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -279,12 +279,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -323,12 +323,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -391,7 +391,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
Expand All @@ -403,7 +403,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
manager.remove_obsolete_receipts().await.unwrap();
}

let rav_request_1_result = manager.create_rav_request(0, None).await;
let rav_request_1_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_1_result.is_ok());

let rav_request_1 = rav_request_1_result.unwrap();
Expand Down Expand Up @@ -438,7 +438,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
Expand All @@ -458,7 +458,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
);
}

let rav_request_2_result = manager.create_rav_request(0, None).await;
let rav_request_2_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_2_result.is_ok());

let rav_request_2 = rav_request_2_result.unwrap();
Expand Down Expand Up @@ -518,12 +518,15 @@ async fn manager_create_rav_and_ignore_invalid_receipts(
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
stored_signed_receipts.push(signed_receipt.clone());
manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.unwrap();
}

let rav_request = manager.create_rav_request(0, None).await.unwrap();
let rav_request = manager
.create_rav_request(&Context::new(), 0, None)
.await
.unwrap();
let expected_rav = rav_request.expected_rav.unwrap();

assert_eq!(rav_request.valid_receipts.len(), 1);
Expand All @@ -544,7 +547,11 @@ async fn test_retryable_checks(

#[async_trait::async_trait]
impl Check for RetryableCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> Result<(), CheckError> {
async fn check(
&self,
_: &Context,
receipt: &ReceiptWithState<Checking>,
) -> Result<(), CheckError> {
// we want to fail only if nonce is 5 and if is create rav step
if self.0.load(std::sync::atomic::Ordering::SeqCst)
&& receipt.signed_receipt().message.nonce == 5
Expand Down Expand Up @@ -591,14 +598,14 @@ async fn test_retryable_checks(
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
stored_signed_receipts.push(signed_receipt.clone());
manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.unwrap();
}

is_create_rav.store(true, std::sync::atomic::Ordering::SeqCst);

let rav_request = manager.create_rav_request(0, None).await;
let rav_request = manager.create_rav_request(&Context::new(), 0, None).await;

assert_eq!(
rav_request.expect_err("Didn't fail").to_string(),
Expand Down
Loading

0 comments on commit 58a6a52

Please sign in to comment.