diff --git a/kad/triert/config.go b/kad/triert/config.go index 967702b..90a2da7 100644 --- a/kad/triert/config.go +++ b/kad/triert/config.go @@ -6,14 +6,20 @@ import ( // Config holds configuration options for a TrieRT. type Config[K kad.Key[K], N kad.NodeID[K]] struct { - // KeyFilter defines the filter that is applied before a key is added to the table. + // KeyFilter defines the filter that is applied before a key is added to the + // table. KeyFilter is applied before NodeFilter. // If nil, no filter is applied. KeyFilter KeyFilterFunc[K, N] + // NodeFilter defines the filter that is applied before a node is added to + // the table. NodeFilter is applied after KeyFilter. + // If nil, no filter is applied. + NodeFilter NodeFilter[K, N] } // DefaultConfig returns a default configuration for a TrieRT. func DefaultConfig[K kad.Key[K], N kad.NodeID[K]]() *Config[K, N] { return &Config[K, N]{ - KeyFilter: nil, + KeyFilter: nil, + NodeFilter: nil, } } diff --git a/kad/triert/filter.go b/kad/triert/filter.go index 7f4202d..0e63081 100644 --- a/kad/triert/filter.go +++ b/kad/triert/filter.go @@ -11,3 +11,16 @@ func BucketLimit20[K kad.Key[K], N kad.NodeID[K]](rt *TrieRT[K, N], kk K) bool { cpl := rt.Cpl(kk) return rt.CplSize(cpl) < 20 } + +// NodeFilter provides a stateful way to filter nodes before they are added to +// the table. The filter is applied after the key filter. +type NodeFilter[K kad.Key[K], N kad.NodeID[K]] interface { + // TryAdd is called when a node is added to the table. Return true to allow + // the node to be added. Return false to prevent the node from being added + // to the table. When updating its state, the NodeFilter considers that the + // node has been added to the table if TryAdd returns true. + TryAdd(rt *TrieRT[K, N], node N) bool + // Remove is called when a node is removed from the table, allowing the + // filter to update its state. + Remove(node N) +} diff --git a/kad/triert/table.go b/kad/triert/table.go index 32a3a1e..ab94a0f 100644 --- a/kad/triert/table.go +++ b/kad/triert/table.go @@ -14,8 +14,9 @@ import ( // TrieRT is a routing table backed by a XOR Trie which offers good scalablity and performance // for large networks. type TrieRT[K kad.Key[K], N kad.NodeID[K]] struct { - self K - keyFilter KeyFilterFunc[K, N] + self K + keyFilter KeyFilterFunc[K, N] + nodeFilter NodeFilter[K, N] mu sync.Mutex // held to synchronise mutations to the trie trie atomic.Value // holds a *trie.Trie[K, N] @@ -43,6 +44,7 @@ func (rt *TrieRT[K, N]) apply(cfg *Config[K, N]) error { } rt.keyFilter = cfg.KeyFilter + rt.nodeFilter = cfg.NodeFilter return nil } @@ -58,6 +60,9 @@ func (rt *TrieRT[K, N]) AddNode(node N) bool { if rt.keyFilter != nil && !rt.keyFilter(rt, kk) { return false } + if rt.nodeFilter != nil && !rt.nodeFilter.TryAdd(rt, node) { + return false + } rt.mu.Lock() defer rt.mu.Unlock() @@ -81,6 +86,10 @@ func (rt *TrieRT[K, N]) RemoveKey(kk K) bool { return false } rt.trie.Store(next) + if rt.nodeFilter != nil { + _, node := trie.Find(this, kk) + rt.nodeFilter.Remove(node) + } return true } diff --git a/kad/triert/table_test.go b/kad/triert/table_test.go index e6d7523..caf1f80 100644 --- a/kad/triert/table_test.go +++ b/kad/triert/table_test.go @@ -314,6 +314,64 @@ func TestKeyFilter(t *testing.T) { require.Equal(t, want, got) } +var _ NodeFilter[kadtest.Key32, node[kadtest.Key32]] = (*nodeFilter)(nil) + +type nodeFilter struct { + state []node[kadtest.Key32] +} + +func (f *nodeFilter) TryAdd(rt *TrieRT[kadtest.Key32, node[kadtest.Key32]], + n node[kadtest.Key32]) bool { + if n == node2 { + return false + } + f.state = append(f.state, n) + return true +} + +func (f *nodeFilter) Remove(n node[kadtest.Key32]) { + for i, nn := range f.state { + if nn == n { + f.state = append(f.state[:i], f.state[i+1:]...) + return + } + } +} + +func TestPeerFilter(t *testing.T) { + cfg := DefaultConfig[kadtest.Key32, node[kadtest.Key32]]() + filter := &nodeFilter{} + cfg.NodeFilter = filter + rt, err := New[kadtest.Key32](node0, cfg) + require.NoError(t, err) + + // can't add node2 + success := rt.AddNode(node2) + require.NoError(t, err) + require.False(t, success) + require.Empty(t, filter.state) + + got, found := rt.GetNode(key2) + require.False(t, found) + require.Zero(t, got) + + // can add other node + success = rt.AddNode(node1) + require.NoError(t, err) + require.True(t, success) + require.Equal(t, []node[kadtest.Key32]{node1}, filter.state) + + want := node1 + got, found = rt.GetNode(key1) + require.True(t, found) + require.Equal(t, want, got) + + // remove node1 + success = rt.RemoveKey(key1) + require.True(t, success) + require.Empty(t, filter.state) +} + func TestTableConcurrentReadWrite(t *testing.T) { nodes := make([]*kadtest.ID[kadtest.Key32], 5000) for i := range nodes {