Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Filter to TrieWalker #1598

Merged
merged 5 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions portal-bridge/src/bridge/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ impl StateBridge {

let root_hash = evm_db.trie.lock().root_hash()?;
let mut content_idx = 0;
let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone())?;
let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone(), None)?;
for account_proof in state_walker {
// gossip the account
self.gossip_account(&account_proof, block_hash, content_idx)
Expand Down Expand Up @@ -426,7 +426,7 @@ impl StateBridge {
let account_db = AccountDB::new(address_hash, evm_db.db.clone());
let trie = EthTrie::from(Arc::new(account_db), account.storage_root)?.db;

let storage_walker = TrieWalker::new(account.storage_root, trie)?;
let storage_walker = TrieWalker::new(account.storage_root, trie, None)?;
for storage_proof in storage_walker {
self.gossip_storage(
&account_proof,
Expand Down
1 change: 1 addition & 0 deletions trin-execution/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jsonrpsee = { workspace = true, features = ["async-client", "client", "macros",
lazy_static.workspace = true
parking_lot.workspace = true
prometheus_exporter.workspace = true
rand.workspace = true
rayon = "1.10.0"
reqwest = { workspace = true, features = ["stream"] }
revm.workspace = true
Expand Down
131 changes: 131 additions & 0 deletions trin-execution/src/trie_walker/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use alloy::primitives::U256;
use rand::{thread_rng, Rng};

use crate::utils::nibbles_to_right_padded_b256;

#[derive(Debug, Clone)]
pub struct Filter {
start: U256,
end: U256,
}

impl Filter {
/// Create a new filter that includes the whole trie
/// Slice index must be less than slice count or it will panic
pub fn new(slice_index: u16, slice_count: u16) -> Self {
// if slice_count is 0 or 1, we want to include the whole trie
if slice_count == 0 || slice_count == 1 {
return Self {
start: U256::ZERO,
end: U256::MAX,
};
}

assert!(
slice_index < slice_count,
"slice_index must be less than slice_count"
);

let slice_size = U256::MAX / U256::from(slice_count) + U256::from(1);

let start = U256::from(slice_index) * slice_size;
let end = if slice_index == slice_count - 1 {
U256::MAX
} else {
start + slice_size - U256::from(1)
};

Self { start, end }
}

pub fn random(slice_count: u16) -> Self {
// if slice_count is 0 or 1, we want to include the whole trie
if slice_count == 0 || slice_count == 1 {
return Self {
start: U256::ZERO,
end: U256::MAX,
};
}

Self::new(thread_rng().gen_range(0..slice_count), slice_count)
}

/// Check if a path is included in the filter
pub fn contains(&self, path: &[u8]) -> bool {
// we need to use partial prefixes to not artificially exclude paths that are not exactly
// the same length as the filter
let shift_amount = 256 - path.len() * 4;
let partial_start_prefix = (self.start >> shift_amount) << shift_amount;
let partial_end_prefix = (self.end >> shift_amount) << shift_amount;
let path = nibbles_to_right_padded_b256(path);

(partial_start_prefix..=partial_end_prefix).contains(&path.into())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn random() {
let filter = Filter::random(0);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);

let filter = Filter::random(1);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);
}

#[test]
fn contains() {
let filter = Filter {
start: nibbles_to_right_padded_b256(&[0x1, 0x5, 0x5]).into(),
end: nibbles_to_right_padded_b256(&[0x3]).into(),
};

assert!(!filter.contains(&[0x0]));
assert!(filter.contains(&[0x1]));
assert!(filter.contains(&[0x1, 0x5]));
assert!(filter.contains(&[0x1, 0x5, 0x5]));
assert!(!filter.contains(&[0x1, 0x5, 0x4]));
assert!(filter.contains(&[0x2]));
assert!(filter.contains(&[0x3]));
assert!(!filter.contains(&[0x3, 0x0, 0x1]));
assert!(!filter.contains(&[0x4]));
}

#[test]
fn new() {
let filter = Filter::new(0, 1);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);

let filter = Filter::new(0, 2);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX / U256::from(2));

let filter = Filter::new(1, 2);
assert_eq!(filter.start, U256::MAX / U256::from(2) + U256::from(1));
assert_eq!(filter.end, U256::MAX);

let filter = Filter::new(0, 3);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX / U256::from(3));

let filter = Filter::new(1, 3);
assert_eq!(filter.start, U256::MAX / U256::from(3) + U256::from(1));
assert_eq!(
filter.end,
U256::MAX / U256::from(3) * U256::from(2) + U256::from(1)
);

let filter = Filter::new(2, 3);
assert_eq!(
filter.start,
U256::MAX / U256::from(3) * U256::from(2) + U256::from(2)
);
assert_eq!(filter.end, U256::MAX);
}
}
19 changes: 17 additions & 2 deletions trin-execution/src/trie_walker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
pub mod db;
pub mod filter;

use std::sync::Arc;

use alloy::primitives::{Bytes, B256};
use anyhow::{anyhow, Ok};
use db::TrieWalkerDb;
use eth_trie::{decode_node, node::Node};
use filter::Filter;

use crate::types::trie_proof::TrieProof;

Expand All @@ -21,10 +23,13 @@ pub struct TrieWalker<DB: TrieWalkerDb> {
is_partial_trie: bool,
trie: Arc<DB>,
stack: Vec<TrieProof>,

/// You can filter what slice of the trie you want to walk
filter: Option<Filter>,
}

impl<DB: TrieWalkerDb> TrieWalker<DB> {
pub fn new(root_hash: B256, trie: Arc<DB>) -> anyhow::Result<Self> {
pub fn new(root_hash: B256, trie: Arc<DB>, filter: Option<Filter>) -> anyhow::Result<Self> {
let root_node_trie = match trie.get(root_hash.as_slice())? {
Some(root_node_trie) => root_node_trie,
None => return Err(anyhow!("Root node not found in the database")),
Expand All @@ -38,6 +43,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: false,
trie,
stack: vec![root_proof],
filter,
})
}

Expand All @@ -52,6 +58,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: true,
trie: Arc::new(trie),
stack: vec![],
filter: None,
});
}
};
Expand All @@ -65,6 +72,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: true,
trie: Arc::new(trie),
stack: vec![root_proof],
filter: None,
})
}

Expand All @@ -74,6 +82,13 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
partial_proof: Vec<Bytes>,
path: Vec<u8>,
) -> anyhow::Result<()> {
// If we have a filter, we only want to include nodes that are in the filter
if let Some(filter) = &self.filter {
if !filter.contains(&path) {
return Ok(());
}
}

// We only need to process hash nodes, because if the node isn't a hash node then none of
// its children is
if let Node::Hash(hash) = node {
Expand Down Expand Up @@ -191,7 +206,7 @@ mod tests {
}

let root_hash = trie.root_hash().unwrap();
let walker = TrieWalker::new(root_hash, trie.db.clone()).unwrap();
let walker = TrieWalker::new(root_hash, trie.db.clone(), None).unwrap();
let mut count = 0;
let mut leaf_count = 0;
for proof in walker {
Expand Down
34 changes: 26 additions & 8 deletions trin-execution/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 {
);
}

let mut raw_address_hash = vec![];
for i in 0..key_path.len() {
nibbles_to_right_padded_b256(key_path)
}

pub fn nibbles_to_right_padded_b256(nibbles: &[u8]) -> B256 {
let mut result = B256::ZERO;
for (i, nibble) in nibbles.iter().enumerate() {
if i % 2 == 0 {
raw_address_hash.push(key_path[i] << 4);
result[i / 2] |= nibble << 4;
} else {
raw_address_hash[i / 2] |= key_path[i];
}
result[i / 2] |= nibble;
};
}
B256::from_slice(&raw_address_hash)
result
}

pub fn address_to_nibble_path(address: Address) -> Vec<u8> {
Expand All @@ -28,10 +32,13 @@ pub fn address_to_nibble_path(address: Address) -> Vec<u8> {

#[cfg(test)]
mod tests {
use alloy::hex::FromHex;
use eth_trie::nibbles::Nibbles as EthNibbles;
use revm_primitives::{keccak256, Address};
use revm_primitives::{keccak256, Address, B256};

use crate::utils::{address_to_nibble_path, full_nibble_path_to_address_hash};
use crate::utils::{
address_to_nibble_path, full_nibble_path_to_address_hash, nibbles_to_right_padded_b256,
};

#[test]
fn test_eth_trie_and_ethportalapi_nibbles() {
Expand All @@ -52,4 +59,15 @@ mod tests {
let generated_address_hash = full_nibble_path_to_address_hash(&path);
assert_eq!(address_hash, generated_address_hash);
}

#[test]
fn test_partial_nibble_path_to_right_padded_b256() {
let partial_nibble_path = vec![0xf, 0xf, 0x0, 0x1, 0x0, 0x2, 0x0, 0x3];
let partial_path = nibbles_to_right_padded_b256(&partial_nibble_path);
assert_eq!(
partial_path,
B256::from_hex("0xff01020300000000000000000000000000000000000000000000000000000000")
.unwrap()
);
}
}