diff --git a/internal/io/mqtt/client/client.go b/internal/io/mqtt/client/client.go index 2adc9ab369..07672bc1cd 100644 --- a/internal/io/mqtt/client/client.go +++ b/internal/io/mqtt/client/client.go @@ -41,6 +41,8 @@ type Connection struct { status atomic.Value scHandler api.StatusChangeHandler conf *ConnectionConfig + // key is the topic. Each topic will have only one connector + subscriptions map[string]*subscriptionInfo } type ConnectionConfig struct { @@ -53,8 +55,15 @@ type ConnectionConfig struct { tls *tls.Config } +type subscriptionInfo struct { + Qos byte + Handler pahoMqtt.MessageHandler +} + func CreateConnection(_ api.StreamContext) modules.Connection { - return &Connection{} + return &Connection{ + subscriptions: make(map[string]*subscriptionInfo), + } } func (conn *Connection) Provision(ctx api.StreamContext, conId string, props map[string]any) error { @@ -62,7 +71,7 @@ func (conn *Connection) Provision(ctx api.StreamContext, conId string, props map if err != nil { return err } - opts := pahoMqtt.NewClientOptions().AddBroker(c.Server).SetProtocolVersion(c.pversion).SetAutoReconnect(true).SetMaxReconnectInterval(time.Minute) + opts := pahoMqtt.NewClientOptions().AddBroker(c.Server).SetProtocolVersion(c.pversion).SetAutoReconnect(true).SetMaxReconnectInterval(connection.DefaultMaxInterval) opts = opts.SetTLSConfig(c.tls) @@ -72,7 +81,6 @@ func (conn *Connection) Provision(ctx api.StreamContext, conId string, props map if c.Password != "" { opts = opts.SetPassword(c.Password) } - opts = opts.SetClientID(c.ClientId).SetAutoReconnect(true).SetResumeSubs(true).SetMaxReconnectInterval(connection.DefaultMaxInterval) conn.status.Store(modules.ConnectionStatus{Status: api.ConnectionConnecting}) opts.OnConnect = conn.onConnect @@ -119,6 +127,12 @@ func (conn *Connection) onConnect(_ pahoMqtt.Client) { conn.scHandler(api.ConnectionConnected, "") } conn.logger.Infof("The connection to mqtt broker is established") + for topic, info := range conn.subscriptions { + err := conn.Subscribe(topic, info.Qos, info.Handler) + if err != nil { // should never happen. If happens because of connection, it will retry later + conn.logger.Errorf("Failed to subscribe topic %s: %v", topic, err) + } + } } func (conn *Connection) onConnectLost(_ pahoMqtt.Client, err error) { @@ -143,6 +157,7 @@ func (conn *Connection) DetachSub(ctx api.StreamContext, props map[string]any) { if err != nil { return } + delete(conn.subscriptions, topic) conn.Client.Unsubscribe(topic) } @@ -173,6 +188,10 @@ func (conn *Connection) Publish(topic string, qos byte, retained bool, payload a } func (conn *Connection) Subscribe(topic string, qos byte, callback pahoMqtt.MessageHandler) error { + conn.subscriptions[topic] = &subscriptionInfo{ + Qos: qos, + Handler: callback, + } token := conn.Client.Subscribe(topic, qos, callback) return handleToken(token) } diff --git a/internal/io/mqtt/source_sink_test.go b/internal/io/mqtt/source_sink_test.go index 065b90e860..be8a9a04f1 100644 --- a/internal/io/mqtt/source_sink_test.go +++ b/internal/io/mqtt/source_sink_test.go @@ -15,27 +15,39 @@ package mqtt import ( + "fmt" "testing" + "time" "github.com/lf-edge/ekuiper/contract/v2/api" + mqtt "github.com/mochi-mqtt/server/v2" + "github.com/mochi-mqtt/server/v2/hooks/auth" + "github.com/mochi-mqtt/server/v2/listeners" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/lf-edge/ekuiper/v2/internal/conf" "github.com/lf-edge/ekuiper/v2/internal/pkg/store" - "github.com/lf-edge/ekuiper/v2/internal/testx" "github.com/lf-edge/ekuiper/v2/internal/topo/topotest/mockclock" "github.com/lf-edge/ekuiper/v2/pkg/connection" "github.com/lf-edge/ekuiper/v2/pkg/mock" "github.com/lf-edge/ekuiper/v2/pkg/model" ) -func TestSourceSink(t *testing.T) { - url, cancel, err := testx.InitBroker("TestSourceSink") +func TestSourceSinkRecon(t *testing.T) { + // Create the new MQTT Server. + server := mqtt.New(nil) + // Allow all connections. + _ = server.AddHook(new(auth.AllowHook), nil) + // Create a TCP listener on a standard port. + tcp := listeners.NewTCP(listeners.Config{ID: "testcon", Address: ":2883"}) + err := server.AddListener(tcp) require.NoError(t, err) - defer func() { - cancel() + go func() { + err = server.Serve() + fmt.Println(err) }() + url := tcp.Address() dataDir, err := conf.GetDataLoc() require.NoError(t, err) require.NoError(t, store.SetupDefault(dataDir)) @@ -74,7 +86,25 @@ func TestSourceSink(t *testing.T) { "qos": 0, "topic": "demo", }, result, func() { - err := mock.RunBytesSinkCollect(sk, data, map[string]any{ + err := mock.RunBytesSinkCollect(sk, data[:1], map[string]any{ + "server": url, + "topic": "demo", + "qos": 0, + "retained": false, + }) + assert.NoError(t, err) + err = server.Close() + tcp.Close(nil) + assert.NoError(t, err) + go func() { + tcp := listeners.NewTCP(listeners.Config{Address: url}) + err := server.AddListener(tcp) + require.NoError(t, err) + err = server.Serve() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + err = mock.RunBytesSinkCollect(sk, data[1:], map[string]any{ "server": url, "topic": "demo", "qos": 0, diff --git a/pkg/mock/test_source.go b/pkg/mock/test_source.go index c64113d572..528f7fadfa 100644 --- a/pkg/mock/test_source.go +++ b/pkg/mock/test_source.go @@ -162,7 +162,7 @@ func TestSourceConnectorCompare(t *testing.T, r api.Source, props map[string]any assert.NoError(t, err) }() - ticker := time.After(60000 * time.Second) + ticker := time.After(60 * time.Second) finished := make(chan struct{}) go func() { wg.Wait()