Skip to content

Commit

Permalink
chore: MPT cleanups (#56)
Browse files Browse the repository at this point in the history
* MPT cleanups

* avoid clone when calling hash

* fix clippy
  • Loading branch information
Wollac authored Nov 21, 2023
1 parent 71135b7 commit 59c23fd
Showing 1 changed file with 100 additions and 83 deletions.
183 changes: 100 additions & 83 deletions primitives/src/trie/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,6 @@ impl MptNode {
&self.data
}

/// Computes and returns the 256-bit hash of the node.
///
/// This method provides a unique identifier for the node based on its content.
pub fn hash(&self) -> B256 {
match self.data {
MptNodeData::Null => EMPTY_ROOT,
_ => match self.reference() {
MptNodeReference::Digest(digest) => digest,
MptNodeReference::Bytes(bytes) => keccak(bytes).into(),
},
}
}

/// Retrieves the [MptNodeReference] reference of the node when it's referenced inside
/// another node.
///
Expand All @@ -296,14 +283,33 @@ impl MptNode {
.clone()
}

/// Computes and returns the 256-bit hash of the node.
///
/// This method provides a unique identifier for the node based on its content.
pub fn hash(&self) -> B256 {
match self.data {
MptNodeData::Null => EMPTY_ROOT,
_ => match self
.cached_reference
.borrow_mut()
.get_or_insert_with(|| self.calc_reference())
{
MptNodeReference::Digest(digest) => *digest,
MptNodeReference::Bytes(bytes) => keccak(bytes).into(),
},
}
}

/// Encodes the [MptNodeReference] of this node into the `out` buffer.
fn reference_encode(&self, out: &mut dyn alloy_rlp::BufMut) {
match self
.cached_reference
.borrow_mut()
.get_or_insert_with(|| self.calc_reference())
{
// if the reference is an RLP-encoded byte slice, copy it directly
MptNodeReference::Bytes(bytes) => out.put_slice(bytes),
// if the reference is a digest, RLP-encode it with its fixed known length
MptNodeReference::Digest(digest) => {
out.put_u8(alloy_rlp::EMPTY_STRING_CODE + 32);
out.put_slice(digest.as_slice());
Expand Down Expand Up @@ -360,22 +366,7 @@ impl MptNode {
pub fn nibs(&self) -> Vec<u8> {
match &self.data {
MptNodeData::Null | MptNodeData::Branch(_) | MptNodeData::Digest(_) => vec![],
MptNodeData::Leaf(prefix, _) | MptNodeData::Extension(prefix, _) => {
let extension = prefix[0];
// the first bit of the first nibble denotes the parity
let is_odd = extension & (1 << 4) != 0;

let mut result = Vec::with_capacity(2 * prefix.len() - 1);
// for odd lengths, the second nibble contains the first element
if is_odd {
result.push(extension & 0xf);
}
for nib in &prefix[1..] {
result.push(nib >> 4);
result.push(nib & 0xf);
}
result
}
MptNodeData::Leaf(prefix, _) | MptNodeData::Extension(prefix, _) => prefix_nibs(prefix),
}
}

Expand Down Expand Up @@ -403,30 +394,27 @@ impl MptNode {
match &self.data {
MptNodeData::Null => Ok(None),
MptNodeData::Branch(nodes) => {
if key_nibs.is_empty() {
Ok(None)
if let Some((i, tail)) = key_nibs.split_first() {
match nodes[*i as usize] {
Some(ref node) => node.get_internal(tail),
None => Ok(None),
}
} else {
nodes
.get(key_nibs[0] as usize)
.unwrap()
.as_ref()
.map_or(Ok(None), |n| n.get_internal(&key_nibs[1..]))
Ok(None)
}
}
MptNodeData::Leaf(_, value) => {
if self.nibs() == key_nibs {
MptNodeData::Leaf(prefix, value) => {
if prefix_nibs(prefix) == key_nibs {
Ok(Some(value))
} else {
Ok(None)
}
}
MptNodeData::Extension(_, node) => {
let ext_nibs = self.nibs();
let ext_len = ext_nibs.len();
if key_nibs[..ext_len] != ext_nibs {
Ok(None)
MptNodeData::Extension(prefix, node) => {
if let Some(tail) = key_nibs.strip_prefix(prefix_nibs(prefix).as_slice()) {
node.get_internal(tail)
} else {
node.get_internal(&key_nibs[ext_len..])
Ok(None)
}
}
MptNodeData::Digest(digest) => Err(Error::NodeNotResolved(*digest)),
Expand All @@ -442,25 +430,25 @@ impl MptNode {
}

fn delete_internal(&mut self, key_nibs: &[u8]) -> Result<bool, Error> {
let mut self_nibs = self.nibs();
match &mut self.data {
MptNodeData::Null => return Ok(false),
MptNodeData::Branch(children) => {
if key_nibs.is_empty() {
return Err(Error::ValueInBranch);
}
let child = children.get_mut(key_nibs[0] as usize).unwrap();
match child {
Some(node) => {
if !node.delete_internal(&key_nibs[1..])? {
return Ok(false);
}
// if the node is now empty, remove it
if node.is_empty() {
*child = None;
if let Some((i, tail)) = key_nibs.split_first() {
let child = &mut children[*i as usize];
match child {
Some(node) => {
if !node.delete_internal(tail)? {
return Ok(false);
}
// if the node is now empty, remove it
if node.is_empty() {
*child = None;
}
}
None => return Ok(false),
}
None => return Ok(false),
} else {
return Err(Error::ValueInBranch);
}

let mut remaining = children.iter_mut().enumerate().filter(|(_, n)| n.is_some());
Expand Down Expand Up @@ -501,18 +489,19 @@ impl MptNode {
}
}
}
MptNodeData::Leaf(_, _) => {
if self_nibs != key_nibs {
MptNodeData::Leaf(prefix, _) => {
if prefix_nibs(prefix) != key_nibs {
return Ok(false);
}
self.data = MptNodeData::Null;
}
MptNodeData::Extension(_, child) => {
let ext_len = self_nibs.len();
if key_nibs[..ext_len] != self_nibs {
return Ok(false);
}
if !child.delete_internal(&key_nibs[ext_len..])? {
MptNodeData::Extension(prefix, child) => {
let mut self_nibs = prefix_nibs(prefix);
if let Some(tail) = key_nibs.strip_prefix(self_nibs.as_slice()) {
if !child.delete_internal(tail)? {
return Ok(false);
}
} else {
return Ok(false);
}

Expand Down Expand Up @@ -568,31 +557,32 @@ impl MptNode {
}

fn insert_internal(&mut self, key_nibs: &[u8], value: Vec<u8>) -> Result<bool, Error> {
let self_nibs = self.nibs();
match &mut self.data {
MptNodeData::Null => {
self.data = MptNodeData::Leaf(to_encoded_path(key_nibs, true), value);
}
MptNodeData::Branch(children) => {
if key_nibs.is_empty() {
return Err(Error::ValueInBranch);
}
let child = children.get_mut(key_nibs[0] as usize).unwrap();
match child {
Some(node) => {
if !node.insert_internal(&key_nibs[1..], value)? {
return Ok(false);
if let Some((i, tail)) = key_nibs.split_first() {
let child = &mut children[*i as usize];
match child {
Some(node) => {
if !node.insert_internal(tail, value)? {
return Ok(false);
}
}
// if the corresponding child is empty, insert a new leaf
None => {
*child = Some(Box::new(
MptNodeData::Leaf(to_encoded_path(tail, true), value).into(),
));
}
}
// if the corresponding child is empty, insert a new leaf
None => {
*child = Some(Box::new(
MptNodeData::Leaf(to_encoded_path(&key_nibs[1..], true), value).into(),
));
}
} else {
return Err(Error::ValueInBranch);
}
}
MptNodeData::Leaf(_, old_value) => {
MptNodeData::Leaf(prefix, old_value) => {
let self_nibs = prefix_nibs(prefix);
let common_len = lcp(&self_nibs, key_nibs);
if common_len == self_nibs.len() && common_len == key_nibs.len() {
// if self_nibs == key_nibs, update the value if it is different
Expand Down Expand Up @@ -631,7 +621,8 @@ impl MptNode {
}
}
}
MptNodeData::Extension(_, existing_child) => {
MptNodeData::Extension(prefix, existing_child) => {
let self_nibs = prefix_nibs(prefix);
let common_len = lcp(&self_nibs, key_nibs);
if common_len == self_nibs.len() {
// traverse down for update
Expand Down Expand Up @@ -806,6 +797,23 @@ fn lcp(a: &[u8], b: &[u8]) -> usize {
cmp::min(a.len(), b.len())
}

fn prefix_nibs(prefix: &[u8]) -> Vec<u8> {
let (extension, tail) = prefix.split_first().unwrap();
// the first bit of the first nibble denotes the parity
let is_odd = extension & (1 << 4) != 0;

let mut result = Vec::with_capacity(2 * tail.len() + is_odd as usize);
// for odd lengths, the second nibble contains the first element
if is_odd {
result.push(extension & 0xf);
}
for nib in tail {
result.push(nib >> 4);
result.push(nib & 0xf);
}
result
}

#[cfg(test)]
mod tests {
use hex_literal::hex;
Expand Down Expand Up @@ -880,6 +888,15 @@ mod tests {
assert_eq!(trie.hash(), decoded.hash());
}

#[test]
pub fn test_empty_key() {
let mut trie = MptNode::default();

trie.insert(&[], b"empty".to_vec()).unwrap();
assert_eq!(trie.get(&[]).unwrap(), Some(b"empty".as_ref()));
assert!(trie.delete(&[]).unwrap());
}

#[test]
pub fn test_clear() {
let mut trie = MptNode::default();
Expand Down

0 comments on commit 59c23fd

Please sign in to comment.