Skip to content

Commit

Permalink
feat: add Filter to TrieWalker
Browse files Browse the repository at this point in the history
  • Loading branch information
KolbyML committed Dec 5, 2024
1 parent 3db8488 commit 84dba35
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 16 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions portal-bridge/src/bridge/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use trin_execution::{
import::StateImporter,
utils::{download_with_progress, percentage_from_address_hash},
},
trie_walker::TrieWalker,
trie_walker::{filter::Filter, TrieWalker},
types::{block_to_trace::BlockToTrace, trie_proof::TrieProof},
utils::full_nibble_path_to_address_hash,
};
Expand Down Expand Up @@ -360,7 +360,7 @@ impl StateBridge {

let root_hash = evm_db.trie.lock().root_hash()?;
let mut content_idx = 0;
let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone())?;
let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone(), None)?;
for account_proof in state_walker {
// gossip the account
self.gossip_account(&account_proof, block_hash, content_idx)
Expand Down Expand Up @@ -426,7 +426,7 @@ impl StateBridge {
let account_db = AccountDB::new(address_hash, evm_db.db.clone());
let trie = EthTrie::from(Arc::new(account_db), account.storage_root)?.db;

let storage_walker = TrieWalker::new(account.storage_root, trie)?;
let storage_walker = TrieWalker::new(account.storage_root, trie, None)?;
for storage_proof in storage_walker {
self.gossip_storage(
&account_proof,
Expand Down
1 change: 1 addition & 0 deletions trin-execution/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jsonrpsee = { workspace = true, features = ["async-client", "client", "macros",
lazy_static.workspace = true
parking_lot.workspace = true
prometheus_exporter.workspace = true
rand.workspace = true
rayon = "1.10.0"
reqwest = { workspace = true, features = ["stream"] }
revm.workspace = true
Expand Down
72 changes: 72 additions & 0 deletions trin-execution/src/trie_walker/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use alloy::primitives::{B256, U256};
use rand::{thread_rng, Rng};

use crate::utils::partial_nibble_path_to_right_padded_b256;

#[derive(Debug, Clone)]
pub struct Filter {
start_prefix: B256,
end_prefix: B256,
}

impl Filter {
pub fn new_random_filter(slice_count: u16) -> Self {
// if slice_count is 0 or 1, we want to include the whole trie
if slice_count == 0 || slice_count == 1 {
return Self {
start_prefix: B256::ZERO,
end_prefix: B256::from(U256::MAX),
};
}

let slice_size = U256::MAX / U256::from(slice_count);
let random_slice_index = thread_rng().gen_range(0..slice_count);

let start_prefix = U256::from(random_slice_index) * slice_size;
let end_prefix = if random_slice_index == slice_count - 1 {
U256::MAX
} else {
start_prefix + slice_size - U256::from(1)
};

Self {
start_prefix: B256::from(start_prefix),
end_prefix: B256::from(end_prefix),
}
}

pub fn is_included(&self, path: &[u8]) -> bool {
let path = partial_nibble_path_to_right_padded_b256(path);
(self.start_prefix..=self.end_prefix).contains(&path)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_new_random_filter() {
let filter = Filter::new_random_filter(0);
assert_eq!(filter.start_prefix, B256::ZERO);
assert_eq!(filter.end_prefix, B256::from(U256::MAX));

let filter = Filter::new_random_filter(1);
assert_eq!(filter.start_prefix, B256::ZERO);
assert_eq!(filter.end_prefix, B256::from(U256::MAX));
}

#[test]
fn test_is_included() {
let filter = Filter {
start_prefix: partial_nibble_path_to_right_padded_b256(&[0x1]),
end_prefix: partial_nibble_path_to_right_padded_b256(&[0x3]),
};

assert!(!filter.is_included(&[0x00]));
assert!(filter.is_included(&[0x01]));
assert!(filter.is_included(&[0x02]));
assert!(filter.is_included(&[0x03]));
assert!(!filter.is_included(&[0x04]));
}
}
19 changes: 17 additions & 2 deletions trin-execution/src/trie_walker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
pub mod db;
pub mod filter;

use std::sync::Arc;

use alloy::primitives::{Bytes, B256};
use anyhow::{anyhow, Ok};
use db::TrieWalkerDb;
use eth_trie::{decode_node, node::Node};
use filter::Filter;

use crate::types::trie_proof::TrieProof;

Expand All @@ -21,10 +23,13 @@ pub struct TrieWalker<DB: TrieWalkerDb> {
is_partial_trie: bool,
trie: Arc<DB>,
stack: Vec<TrieProof>,

/// You can filter what slice of the trie you want to walk
filter: Option<Filter>,
}

impl<DB: TrieWalkerDb> TrieWalker<DB> {
pub fn new(root_hash: B256, trie: Arc<DB>) -> anyhow::Result<Self> {
pub fn new(root_hash: B256, trie: Arc<DB>, filter: Option<Filter>) -> anyhow::Result<Self> {
let root_node_trie = match trie.get(root_hash.as_slice())? {
Some(root_node_trie) => root_node_trie,
None => return Err(anyhow!("Root node not found in the database")),
Expand All @@ -38,6 +43,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: false,
trie,
stack: vec![root_proof],
filter,
})
}

Expand All @@ -52,6 +58,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: true,
trie: Arc::new(trie),
stack: vec![],
filter: None,
});
}
};
Expand All @@ -65,6 +72,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
is_partial_trie: true,
trie: Arc::new(trie),
stack: vec![root_proof],
filter: None,
})
}

Expand All @@ -74,6 +82,13 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
partial_proof: Vec<Bytes>,
path: Vec<u8>,
) -> anyhow::Result<()> {
// If we have a filter, we only want to include nodes that are in the filter
if let Some(filter) = &self.filter {
if !filter.is_included(&path) {
return Ok(());
}
}

// We only need to process hash nodes, because if the node isn't a hash node then none of
// its children is
if let Node::Hash(hash) = node {
Expand Down Expand Up @@ -191,7 +206,7 @@ mod tests {
}

let root_hash = trie.root_hash().unwrap();
let walker = TrieWalker::new(root_hash, trie.db.clone()).unwrap();
let walker = TrieWalker::new(root_hash, trie.db.clone(), None).unwrap();
let mut count = 0;
let mut leaf_count = 0;
for proof in walker {
Expand Down
45 changes: 34 additions & 11 deletions trin-execution/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
use alloy::primitives::{keccak256, Address, B256};

fn compress_nibbles(nibbles: &[u8]) -> Vec<u8> {
let mut compressed_nibbles = vec![];
for i in 0..nibbles.len() {
if i % 2 == 0 {
compressed_nibbles.push(nibbles[i] << 4);
} else {
compressed_nibbles[i / 2] |= nibbles[i];
}
}
compressed_nibbles
}

pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 {
if key_path.len() != 64 {
panic!(
Expand All @@ -8,15 +20,11 @@ pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 {
);
}

let mut raw_address_hash = vec![];
for i in 0..key_path.len() {
if i % 2 == 0 {
raw_address_hash.push(key_path[i] << 4);
} else {
raw_address_hash[i / 2] |= key_path[i];
}
}
B256::from_slice(&raw_address_hash)
B256::from_slice(&compress_nibbles(key_path))
}

pub fn partial_nibble_path_to_right_padded_b256(partial_nibble_path: &[u8]) -> B256 {
B256::right_padding_from(&compress_nibbles(partial_nibble_path))
}

pub fn address_to_nibble_path(address: Address) -> Vec<u8> {
Expand All @@ -28,10 +36,14 @@ pub fn address_to_nibble_path(address: Address) -> Vec<u8> {

#[cfg(test)]
mod tests {
use alloy::hex::FromHex;
use eth_trie::nibbles::Nibbles as EthNibbles;
use revm_primitives::{keccak256, Address};
use revm_primitives::{keccak256, Address, B256};

use crate::utils::{address_to_nibble_path, full_nibble_path_to_address_hash};
use crate::utils::{
address_to_nibble_path, full_nibble_path_to_address_hash,
partial_nibble_path_to_right_padded_b256,
};

#[test]
fn test_eth_trie_and_ethportalapi_nibbles() {
Expand All @@ -52,4 +64,15 @@ mod tests {
let generated_address_hash = full_nibble_path_to_address_hash(&path);
assert_eq!(address_hash, generated_address_hash);
}

#[test]
fn test_partial_nibble_path_to_right_padded_b256() {
let partial_nibble_path = vec![0xf, 0xf, 0x0, 0x1, 0x0, 0x2, 0x0, 0x3];
let partial_path = partial_nibble_path_to_right_padded_b256(&partial_nibble_path);
assert_eq!(
partial_path,
B256::from_hex("0xff01020300000000000000000000000000000000000000000000000000000000")
.unwrap()
);
}
}

0 comments on commit 84dba35

Please sign in to comment.