Skip to content

Commit

Permalink
feat(rln-relay): ensure execution order for pubsub validators
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-ramos committed Sep 14, 2023
1 parent ab7e45c commit 7beaa3f
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 102 deletions.
2 changes: 1 addition & 1 deletion examples/chat2/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func execute(options Options) {
}

if options.RLNRelay.Enable {
spamHandler := func(message *pb.WakuMessage) error {
spamHandler := func(message *pb.WakuMessage, topic string) error {
return nil
}

Expand Down
5 changes: 2 additions & 3 deletions waku/v2/node/wakunode2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
backoffv4 "github.com/cenkalti/backoff/v4"
golog "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"go.uber.org/zap"

"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -66,13 +65,13 @@ type IdentityCredential = struct {
IDCommitment byte32 `json:"idCommitment"`
}

type SpamHandler = func(message *pb.WakuMessage) error
type SpamHandler = func(message *pb.WakuMessage, topic string) error

type RLNRelay interface {
IdentityCredential() (IdentityCredential, error)
MembershipIndex() uint
AppendRLNProof(msg *pb.WakuMessage, senderEpochTime time.Time) error
Validator(spamHandler SpamHandler) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool
Validator(spamHandler SpamHandler) func(ctx context.Context, message *pb.WakuMessage, topic string) bool
Start(ctx context.Context) error
Stop() error
}
Expand Down
8 changes: 5 additions & 3 deletions waku/v2/node/wakunode2_rln.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"

pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/waku-org/go-waku/waku/v2/protocol/rln"
"github.com/waku-org/go-waku/waku/v2/protocol/rln/group_manager"
"github.com/waku-org/go-waku/waku/v2/protocol/rln/group_manager/dynamic"
Expand All @@ -29,6 +28,10 @@ func (w *WakuNode) setupRLNRelay() error {
return nil
}

if !w.opts.enableRelay {
return errors.New("rln requires relay")
}

var groupManager group_manager.GroupManager

rlnInstance, rootTracker, err := rln.GetRLNInstanceAndRootTracker(w.opts.rlnTreePath)
Expand Down Expand Up @@ -85,8 +88,7 @@ func (w *WakuNode) setupRLNRelay() error {

w.rlnRelay = rlnRelay

// Adding RLN as a default validator
w.opts.pubsubOpts = append(w.opts.pubsubOpts, pubsub.WithDefaultValidator(rlnRelay.Validator(w.opts.rlnSpamHandler)))
w.Relay().RegisterDefaultValidator(w.rlnRelay.Validator(w.opts.rlnSpamHandler))

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion waku/v2/node/wakuoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type WakuNodeParameters struct {
enableRLN bool
rlnRelayMemIndex *uint
rlnRelayDynamic bool
rlnSpamHandler func(message *pb.WakuMessage) error
rlnSpamHandler func(message *pb.WakuMessage, topic string) error
rlnETHClientAddress string
keystorePath string
keystorePassword string
Expand Down
4 changes: 2 additions & 2 deletions waku/v2/protocol/envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ type Envelope struct {
// as well as generating a hash based on the bytes that compose the message
func NewEnvelope(msg *wpb.WakuMessage, receiverTime int64, pubSubTopic string) *Envelope {
messageHash := msg.Hash(pubSubTopic)
hash := hash.SHA256([]byte(msg.ContentTopic), msg.Payload)
digest := hash.SHA256([]byte(msg.ContentTopic), msg.Payload)
return &Envelope{
msg: msg,
hash: messageHash,
index: &pb.Index{
Digest: hash[:],
Digest: digest[:],
ReceiverTime: receiverTime,
SenderTime: msg.Timestamp,
PubsubTopic: pubSubTopic,
Expand Down
89 changes: 61 additions & 28 deletions waku/v2/protocol/relay/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (

"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"

pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
proto "google.golang.org/protobuf/proto"

"github.com/waku-org/go-waku/waku/v2/hash"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
"github.com/waku-org/go-waku/waku/v2/timesource"
"go.uber.org/zap"
proto "google.golang.org/protobuf/proto"
)

func msgHash(pubSubTopic string, msg *pb.WakuMessage) []byte {
Expand All @@ -38,54 +38,60 @@ func msgHash(pubSubTopic string, msg *pb.WakuMessage) []byte {
)
}

const messageWindowDuration = time.Minute * 5
type validatorFn = func(ctx context.Context, msg *pb.WakuMessage, topic string) bool

func withinTimeWindow(t timesource.Timesource, msg *pb.WakuMessage) bool {
if msg.Timestamp == 0 {
return false
}
func (w *WakuRelay) RegisterDefaultValidator(fn validatorFn) {
w.topicValidatorMutex.Lock()
defer w.topicValidatorMutex.Unlock()
w.defaultTopicValidators = append(w.defaultTopicValidators, fn)
}

now := t.Now()
msgTime := time.Unix(0, msg.Timestamp)
func (w *WakuRelay) RegisterTopicValidator(topic string, fn validatorFn) {
w.topicValidatorMutex.Lock()
defer w.topicValidatorMutex.Unlock()

return now.Sub(msgTime).Abs() <= messageWindowDuration
w.topicValidators[topic] = append(w.topicValidators[topic], fn)
}

type validatorFn = func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool
func (w *WakuRelay) RemoveTopicValidator(topic string) {
w.topicValidatorMutex.Lock()
defer w.topicValidatorMutex.Unlock()

func validatorFnBuilder(t timesource.Timesource, topic string, publicKey *ecdsa.PublicKey) (validatorFn, error) {
publicKeyBytes := crypto.FromECDSAPub(publicKey)
delete(w.topicValidators, topic)
}

func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
msg := new(pb.WakuMessage)
err := proto.Unmarshal(message.Data, msg)
if err != nil {
return false
}

if !withinTimeWindow(t, msg) {
return false
w.topicValidatorMutex.RLock()
validators, exists := w.topicValidators[topic]
validators = append(validators, w.defaultTopicValidators...)
w.topicValidatorMutex.RUnlock()

if exists {
for _, v := range validators {
if !v(ctx, msg, topic) {
return false
}
}
}

msgHash := msgHash(topic, msg)
signature := msg.Meta

return secp256k1.VerifySignature(publicKeyBytes, msgHash, signature)
}, nil
return true
}
}

// AddSignedTopicValidator registers a gossipsub validator for a topic which will check that messages Meta field contains a valid ECDSA signature for the specified pubsub topic. This is used as a DoS prevention mechanism
func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) error {
w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("publicKey", hex.EncodeToString(elliptic.Marshal(publicKey.Curve, publicKey.X, publicKey.Y))))

fn, err := validatorFnBuilder(w.timesource, topic, publicKey)
if err != nil {
return err
}
fn := signedTopicBuilder(w.timesource, publicKey)

err = w.pubsub.RegisterTopicValidator(topic, fn)
if err != nil {
return err
}
w.RegisterTopicValidator(topic, fn)

if !w.IsSubscribed(topic) {
w.log.Warn("relay is not subscribed to signed topic", zap.String("topic", topic))
Expand All @@ -94,6 +100,33 @@ func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.Publi
return nil
}

const messageWindowDuration = time.Minute * 5

func withinTimeWindow(t timesource.Timesource, msg *pb.WakuMessage) bool {
if msg.Timestamp == 0 {
return false
}

now := t.Now()
msgTime := time.Unix(0, msg.Timestamp)

return now.Sub(msgTime).Abs() <= messageWindowDuration
}

func signedTopicBuilder(t timesource.Timesource, publicKey *ecdsa.PublicKey) validatorFn {
publicKeyBytes := crypto.FromECDSAPub(publicKey)
return func(ctx context.Context, msg *pb.WakuMessage, topic string) bool {
if !withinTimeWindow(t, msg) {
return false
}

msgHash := msgHash(topic, msg)
signature := msg.Meta

return secp256k1.VerifySignature(publicKeyBytes, msgHash, signature)
}
}

// SignMessage adds an ECDSA signature to a WakuMessage as an opt-in mechanism for DoS prevention
func SignMessage(privKey *ecdsa.PrivateKey, msg *pb.WakuMessage, pubsubTopic string) error {
msgHash := msgHash(pubsubTopic, msg)
Expand Down
31 changes: 6 additions & 25 deletions waku/v2/protocol/relay/validators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ import (
"time"

"github.com/ethereum/go-ethereum/crypto"
pubsub "github.com/libp2p/go-libp2p-pubsub"
pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
"google.golang.org/protobuf/proto"
)

type FakeTimesource struct {
Expand Down Expand Up @@ -59,39 +56,23 @@ func TestMsgHash(t *testing.T) {
// expectedSignature, _ := hex.DecodeString("127FA211B2514F0E974A055392946DC1A14052182A6ABEFB8A6CD7C51DA1BF2E40595D28EF1A9488797C297EED3AAC45430005FB3A7F037BDD9FC4BD99F59E63")
// require.True(t, bytes.Equal(expectedSignature, msg.Meta))

msgData, _ := proto.Marshal(msg)

//expectedMessageHash, _ := hex.DecodeString("662F8C20A335F170BD60ABC1F02AD66F0C6A6EE285DA2A53C95259E7937C0AE9")
//messageHash := MsgHash(pubsubTopic, msg)
//require.True(t, bytes.Equal(expectedMessageHash, messageHash))

myValidator, err := validatorFnBuilder(NewFakeTimesource(timestamp), protectedPubSubTopic, &prvKey.PublicKey)
require.NoError(t, err)
result := myValidator(context.Background(), "", &pubsub.Message{
Message: &pubsub_pb.Message{
Data: msgData,
},
})
myValidator := signedTopicBuilder(NewFakeTimesource(timestamp), &prvKey.PublicKey)
result := myValidator(context.Background(), msg, protectedPubSubTopic)
require.True(t, result)

// Exceed 5m window in both directions
now5m1sInPast := timestamp.Add(-5 * time.Minute).Add(-1 * time.Second)
myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInPast), protectedPubSubTopic, &prvKey.PublicKey)
myValidator = signedTopicBuilder(NewFakeTimesource(now5m1sInPast), &prvKey.PublicKey)
require.NoError(t, err)
result = myValidator(context.Background(), "", &pubsub.Message{
Message: &pubsub_pb.Message{
Data: msgData,
},
})
result = myValidator(context.Background(), msg, protectedPubSubTopic)
require.False(t, result)

now5m1sInFuture := timestamp.Add(5 * time.Minute).Add(1 * time.Second)
myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInFuture), protectedPubSubTopic, &prvKey.PublicKey)
require.NoError(t, err)
result = myValidator(context.Background(), "", &pubsub.Message{
Message: &pubsub_pb.Message{
Data: msgData,
},
})
myValidator = signedTopicBuilder(NewFakeTimesource(now5m1sInFuture), &prvKey.PublicKey)
result = myValidator(context.Background(), msg, protectedPubSubTopic)
require.False(t, result)
}
18 changes: 12 additions & 6 deletions waku/v2/protocol/relay/waku_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type WakuRelay struct {

minPeersToPublish int

topicValidatorMutex sync.RWMutex
topicValidators map[string][]validatorFn
defaultTopicValidators []validatorFn

// TODO: convert to concurrent maps
topicsMutex sync.Mutex
wakuRelayTopics map[string]*pubsub.Topic
Expand Down Expand Up @@ -83,6 +87,7 @@ func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesou
w.timesource = timesource
w.wakuRelayTopics = make(map[string]*pubsub.Topic)
w.relaySubs = make(map[string]*pubsub.Subscription)
w.topicValidators = make(map[string][]validatorFn)
w.bcaster = bcaster
w.minPeersToPublish = minPeersToPublish
w.CommonService = waku_proto.NewCommonService()
Expand Down Expand Up @@ -177,12 +182,6 @@ func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesou
pubsub.WithSeenMessagesTTL(2 * time.Minute),
pubsub.WithPeerScore(w.peerScoreParams, w.peerScoreThresholds),
pubsub.WithPeerScoreInspect(w.peerScoreInspector, 6*time.Second),
// TODO: to improve - setup default validator only if no default validator has been set.
pubsub.WithDefaultValidator(func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
msg := new(pb.WakuMessage)
err := proto.Unmarshal(message.Data, msg)
return err == nil
}),
}, opts...)

return w
Expand Down Expand Up @@ -270,6 +269,11 @@ func (w *WakuRelay) upsertTopic(topic string) (*pubsub.Topic, error) {

pubSubTopic, ok := w.wakuRelayTopics[topic]
if !ok { // Joins topic if node hasn't joined yet
err := w.pubsub.RegisterTopicValidator(topic, w.topicValidator(topic))
if err != nil {
return nil, err
}

newTopic, err := w.pubsub.Join(string(topic))
if err != nil {
return nil, err
Expand Down Expand Up @@ -419,6 +423,8 @@ func (w *WakuRelay) Unsubscribe(ctx context.Context, topic string) error {
}
delete(w.wakuRelayTopics, topic)

w.RemoveTopicValidator(topic)

err = w.emitters.EvtRelayUnsubscribed.Emit(EvtRelayUnsubscribed{topic})
if err != nil {
return err
Expand Down
9 changes: 5 additions & 4 deletions waku/v2/protocol/relay/waku_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ func TestGossipsubScore(t *testing.T) {
relay := make([]*WakuRelay, 5)
for i := 0; i < 5; i++ {
hosts[i], relay[i] = createRelayNode(t)
if i == 0 {
// This is a hack to remove the default validator from the list of default options
relay[i].opts = relay[i].opts[:len(relay[i].opts)-1]
}
err := relay[i].Start(context.Background())
require.NoError(t, err)
}
Expand Down Expand Up @@ -119,6 +115,11 @@ func TestGossipsubScore(t *testing.T) {
// We obtain the go-libp2p topic directly because we normally can't publish anything other than WakuMessages
pubsubTopic, err := relay[0].upsertTopic(testTopic)
require.NoError(t, err)

// Removing validator from relayer0 to allow it to send invalid messages
err = relay[0].pubsub.UnregisterTopicValidator(testTopic)
require.NoError(t, err)

for i := 0; i < 50; i++ {
buf := make([]byte, 1000)
_, err := rand.Read(buf)
Expand Down
2 changes: 1 addition & 1 deletion waku/v2/protocol/rln/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const acceptableRootWindowSize = 5

type RegistrationHandler = func(tx *types.Transaction)

type SpamHandler = func(message *pb.WakuMessage) error
type SpamHandler = func(msg *pb.WakuMessage, topic string) error

func toRLNSignal(wakuMessage *pb.WakuMessage) []byte {
if wakuMessage == nil {
Expand Down
Loading

0 comments on commit 7beaa3f

Please sign in to comment.