Skip to content

Commit

Permalink
Ensure protocol naming matches latest Starknet p2p spec
Browse files Browse the repository at this point in the history
The change ensures that DHT protocol follow latest Starknet
specification.
  • Loading branch information
wojciechos committed Dec 16, 2024
1 parent 4ff174d commit fee852d
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 23 deletions.
16 changes: 8 additions & 8 deletions p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
}
}

p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID)
if err != nil {
return nil, err
}
Expand All @@ -159,9 +159,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
return s, nil
}

func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) {
return dht.New(context.Background(), p2phost,
dht.ProtocolPrefix(starknet.Prefix),
dht.ProtocolPrefix(starknet.DHTPrefixPID(chainID)),
dht.BootstrapPeers(addrInfos...),
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
dht.Mode(dht.ModeServer),
Expand Down Expand Up @@ -249,11 +249,11 @@ func (s *Service) Run(ctx context.Context) error {
}

func (s *Service) setProtocolHandlers() {
s.SetProtocolHandler(starknet.HeadersPID(), s.handler.HeadersHandler)
s.SetProtocolHandler(starknet.EventsPID(), s.handler.EventsHandler)
s.SetProtocolHandler(starknet.TransactionsPID(), s.handler.TransactionsHandler)
s.SetProtocolHandler(starknet.ClassesPID(), s.handler.ClassesHandler)
s.SetProtocolHandler(starknet.StateDiffPID(), s.handler.StateDiffHandler)
s.SetProtocolHandler(starknet.HeadersPID(s.network.L2ChainID), s.handler.HeadersHandler)
s.SetProtocolHandler(starknet.EventsPID(s.network.L2ChainID), s.handler.EventsHandler)
s.SetProtocolHandler(starknet.TransactionsPID(s.network.L2ChainID), s.handler.TransactionsHandler)
s.SetProtocolHandler(starknet.ClassesPID(s.network.L2ChainID), s.handler.ClassesHandler)
s.SetProtocolHandler(starknet.StateDiffPID(s.network.L2ChainID), s.handler.StateDiffHandler)
}

func (s *Service) callAndLogErr(f func() error, msg string) {
Expand Down
36 changes: 36 additions & 0 deletions p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/NethermindEth/juno/p2p"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
)
require.NoError(t, err)
}

func TestMakeDHTProtocolName(t *testing.T) {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
testHost := net.Hosts()[0]

testCases := []struct {
name string
network *utils.Network
expected string
}{
{
name: "sepolia network",
network: &utils.Sepolia,
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
},
{
name: "mainnet network",
network: &utils.Mainnet,
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID)
require.NoError(t, err)

protocols := dht.Host().Mux().Protocols()
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
})
}
}
12 changes: 7 additions & 5 deletions p2p/starknet/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,24 @@ func (c *Client) RequestBlockHeaders(
ctx context.Context, req *spec.BlockHeadersRequest,
) (iter.Seq[*spec.BlockHeadersResponse], error) {
return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse](
ctx, c.newStream, HeadersPID(), req, c.log)
ctx, c.newStream, HeadersPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (iter.Seq[*spec.EventsResponse], error) {
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestClasses(ctx context.Context, req *spec.ClassesRequest) (iter.Seq[*spec.ClassesResponse], error) {
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestStateDiffs(ctx context.Context, req *spec.StateDiffsRequest) (iter.Seq[*spec.StateDiffsResponse], error) {
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](
ctx, c.newStream, StateDiffPID(c.network.L2ChainID), req, c.log,
)
}

func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (iter.Seq[*spec.TransactionsResponse], error) {
return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse](
ctx, c.newStream, TransactionsPID(), req, c.log)
ctx, c.newStream, TransactionsPID(c.network.L2ChainID), req, c.log)
}
25 changes: 15 additions & 10 deletions p2p/starknet/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,27 @@ import (

const Prefix = "/starknet"

func HeadersPID() protocol.ID {
return Prefix + "/headers/0.1.0-rc.0"
func HeadersPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/headers/0.1.0-rc.0")
}

func EventsPID() protocol.ID {
return Prefix + "/events/0.1.0-rc.0"
func EventsPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/events/0.1.0-rc.0")
}

func TransactionsPID() protocol.ID {
return Prefix + "/transactions/0.1.0-rc.0"
func TransactionsPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/transactions/0.1.0-rc.0")
}

func ClassesPID() protocol.ID {
return Prefix + "/classes/0.1.0-rc.0"
func ClassesPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/classes/0.1.0-rc.0")
}

func StateDiffPID() protocol.ID {
return Prefix + "/state_diffs/0.1.0-rc.0"
func StateDiffPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/state_diffs/0.1.0-rc.0")
}

// DHTPrefixPID returns the protocol ID used as the DHT protocol prefix for a specific chain
func DHTPrefixPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync")
}
66 changes: 66 additions & 0 deletions p2p/starknet/ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package starknet

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestProtocolIDs(t *testing.T) {
testCases := []struct {
name string
chainID string
pidFunc func(string) string
expected string
}{
{
name: "HeadersPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(HeadersPID(c)) },
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
},
{
name: "EventsPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(EventsPID(c)) },
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
},
{
name: "TransactionsPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(TransactionsPID(c)) },
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
},
{
name: "ClassesPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(ClassesPID(c)) },
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
},
{
name: "StateDiffPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(StateDiffPID(c)) },
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
},
{
name: "DHTPrefixPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(DHTPrefixPID(c)) },
expected: "/starknet/SN_MAIN/sync",
},
{
name: "HeadersPID with SN_SEPOLIA",
chainID: "SN_SEPOLIA",
pidFunc: func(c string) string { return string(HeadersPID(c)) },
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.pidFunc(tc.chainID)
assert.Equal(t, tc.expected, result)
})
}
}

0 comments on commit fee852d

Please sign in to comment.