diff --git a/waku/v2/protocol/metadata/waku_metadata.go b/waku/v2/protocol/metadata/waku_metadata.go index aae9db62e..71c0dcc09 100644 --- a/waku/v2/protocol/metadata/waku_metadata.go +++ b/waku/v2/protocol/metadata/waku_metadata.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math" - "strings" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/libp2p/go-libp2p/core/host" @@ -57,6 +56,7 @@ func (wakuM *WakuMetadata) SetHost(h host.Host) { func (wakuM *WakuMetadata) Start(ctx context.Context) error { if wakuM.clusterID == 0 { wakuM.log.Warn("no clusterID is specified. Protocol will not be initialized") + return nil } ctx, cancel := context.WithCancel(ctx) @@ -135,16 +135,21 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc stream.Close() if response.ClusterId == nil { - return nil, nil // Node is not using sharding + return nil, errors.New("node did not provide a waku clusterid") } - result := &protocol.RelayShards{} - result.ClusterID = uint16(*response.ClusterId) + rClusterID := uint16(*response.ClusterId) + var rShardIDs []uint16 for _, i := range response.Shards { - result.ShardIDs = append(result.ShardIDs, uint16(i)) + rShardIDs = append(rShardIDs, uint16(i)) } - return result, nil + rs, err := protocol.NewRelayShards(rClusterID, rShardIDs...) + if err != nil { + return nil, err + } + + return &rs, nil } func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(network.Stream) { @@ -209,6 +214,15 @@ func (wakuM *WakuMetadata) ListenClose(n network.Network, m multiaddr.Multiaddr) // Do nothing } +func (wakuM *WakuMetadata) disconnectPeer(peerID peer.ID, reason error) { + logger := wakuM.log.With(logging.HostID("peerID", peerID)) + logger.Error("disconnecting from peer", zap.Error(reason)) + wakuM.h.Peerstore().RemovePeer(peerID) + if err := wakuM.h.Network().ClosePeer(peerID); err != nil { + logger.Error("could not disconnect from peer", zap.Error(err)) + } +} + // Connected is called when a connection is opened func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) { go func() { @@ -219,30 +233,14 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) { peerID := cc.RemotePeer() - logger := wakuM.log.With(logging.HostID("peerID", peerID)) - - shouldDisconnect := true shard, err := wakuM.Request(wakuM.ctx, peerID) - if err == nil { - if shard == nil { - err = errors.New("no shard reported") - } else if shard.ClusterID != wakuM.clusterID { - err = errors.New("different clusterID reported") - } - } else { - // Only disconnect from peers if they support the protocol - // TODO: open a PR in go-libp2p to create a var with this error to not have to compare strings but use errors.Is instead - if strings.Contains(err.Error(), "protocols not supported") { - shouldDisconnect = false - } + if err != nil { + wakuM.disconnectPeer(peerID, err) + return } - if shouldDisconnect && err != nil { - logger.Error("disconnecting from peer", zap.Error(err)) - wakuM.h.Peerstore().RemovePeer(peerID) - if err := wakuM.h.Network().ClosePeer(peerID); err != nil { - logger.Error("could not disconnect from peer", zap.Error(err)) - } + if shard.ClusterID != wakuM.clusterID { + wakuM.disconnectPeer(peerID, errors.New("different clusterID reported")) } }() } diff --git a/waku/v2/protocol/metadata/waku_metadata_test.go b/waku/v2/protocol/metadata/waku_metadata_test.go index 58b462b43..1ccb90a3e 100644 --- a/waku/v2/protocol/metadata/waku_metadata_test.go +++ b/waku/v2/protocol/metadata/waku_metadata_test.go @@ -3,12 +3,15 @@ package metadata import ( "context" "crypto/rand" + "errors" "testing" "time" gcrypto "github.com/ethereum/go-ethereum/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + libp2pProtocol "github.com/libp2p/go-libp2p/core/protocol" + "github.com/multiformats/go-multistream" "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/tests" "github.com/waku-org/go-waku/waku/v2/protocol" @@ -42,6 +45,11 @@ func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata { return m1 } +func isProtocolNotSupported(err error) bool { + notSupportedErr := multistream.ErrNotSupported[libp2pProtocol.ID]{} + return errors.Is(err, notSupportedErr) +} + func TestWakuMetadataRequest(t *testing.T) { testShard16 := uint16(16) @@ -79,11 +87,9 @@ func TestWakuMetadataRequest(t *testing.T) { require.Equal(t, testShard16, result.ClusterID) require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs) - // Query a peer not subscribed to a shard - result, err = m16_1.Request(context.Background(), m_noRS.h.ID()) - require.NoError(t, err) - require.Equal(t, uint16(0), result.ClusterID) - require.Len(t, result.ShardIDs, 0) + // Query a peer not subscribed to any shard + _, err = m16_1.Request(context.Background(), m_noRS.h.ID()) + require.True(t, isProtocolNotSupported(err)) } func TestNoNetwork(t *testing.T) { @@ -93,7 +99,7 @@ func TestNoNetwork(t *testing.T) { require.NoError(t, err) m1 := createWakuMetadata(t, &rs1) - // host2 does not support metadata protocol + // host2 does not support metadata protocol, so it should be dropped port, err := tests.FindFreePort(t, "", 5) require.NoError(t, err) host2, err := tests.MakeHost(context.Background(), port, rand.Reader) @@ -106,12 +112,10 @@ func TestNoNetwork(t *testing.T) { time.Sleep(2 * time.Second) // Verifying peer connections - require.Len(t, m1.h.Network().Peers(), 1) - require.Len(t, host2.Network().Peers(), 1) + require.Len(t, m1.h.Network().Peers(), 0) + require.Len(t, host2.Network().Peers(), 0) } -// go test -timeout 300s -run TestDropConnectionOnDiffNetworks github.com/waku-org/go-waku/waku/v2/protocol/metadata -count 1 -v - func TestDropConnectionOnDiffNetworks(t *testing.T) { cluster1 := uint16(1) cluster2 := uint16(2)