From 4c0c862dcd7dec312ef336c58c73dcb0e377b0d7 Mon Sep 17 00:00:00 2001 From: Marco Debus <81422785+dadebue@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:10:02 +0200 Subject: [PATCH] Fix for unlimited maximum message expiry interval (#315) * fix when no max msg expiry interval is set * fix expiry handling of clearExpiredInflights * Modify it to handle cases where the MaximumMessageExpiryInterval is set to 0 or math.MaxInt64 for no expiry, and optimize some of the code and test cases. * Set MaximumMessageExpiryInterval to 0 or math.MaxInt64 for no expiration, and optimize some of the code and test cases. * Addressing the issue of numeric overflow with expiration values. * Only when server.Options.Capabilities.MaximumMessageExpiryInterval is set to math.MaxInt64 for no expiry. * fix typo in README.md * There is no need to verify whether 'maximumExpiry' is 'math.MaxInt64' within 'client.ClearInflight() * Optimize the code to make it easier to understand. * Differentiate the handling of 'expire' in MQTTv5 and MQTTv3; skip expiration checks if MaximumMessageExpiryInterval is set to 0; optimize code and test cases. * When MaximumMessageExpiryInterval is set to 0, it should not affect the message's own expiration(for v5) evaluation. * Adding client.ClearExpiredInflights() to clear expired messages, while client.ClearInflights() is used to clear all inflight messages. --------- Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com> Co-authored-by: werben Co-authored-by: werben --- README-CN.md | 2 +- README.md | 2 +- clients.go | 20 ++++++++++++++++++-- clients_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++------ server.go | 17 ++++++++++++----- server_test.go | 28 +++++++++++++++++++++++----- 6 files changed, 98 insertions(+), 20 deletions(-) diff --git a/README-CN.md b/README-CN.md index 20db9ac9..684ca190 100644 --- a/README-CN.md +++ b/README-CN.md @@ -183,7 +183,7 @@ server := mqtt.New(&mqtt.Options{ 关于决定默认配置的值,在这里进行一些说明: -- 默认情况下,server.Options.Capabilities.MaximumMessageExpiryInterval 的值被设置为 86400(24小时),以防止在使用默认配置时网络上暴露服务器而受到恶意DOS攻击(如果不配置到期时间将允许无限数量的保留retained/待发送inflight消息累积)。如果您在一个受信任的环境中运行,或者您有更大的保留期容量,您可以选择覆盖此设置(设置为 0 或 math.MaxInt 以取消到期限制)。 +- 默认情况下,server.Options.Capabilities.MaximumMessageExpiryInterval 的值被设置为 86400(24小时),以防止在使用默认配置时网络上暴露服务器而受到恶意DOS攻击(如果不配置到期时间将允许无限数量的保留retained/待发送inflight消息累积)。如果您在一个受信任的环境中运行,或者您有更大的保留期容量,您可以选择覆盖此设置(设置为0 以取消到期限制)。 ## 事件钩子(Event Hooks) diff --git a/README.md b/README.md index bf93a39d..098a31dd 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for Some choices were made when deciding the default configuration that need to be mentioned here: -- By default, the value of `server.Options.Capabilities.MaximumMessageExpiryInterval` is set to 86400 (24 hours), in order to prevent exposing the broker to DOS attacks on hostile networks when using the out-of-the-box configuration (as an infinite expiry would allow an infinite number of retained/inflight messages to accumulate). If you are operating in a trusted environment, or you have capacity for a larger retention period, uou may wish to override this (set to `0` or `math.MaxInt` for no expiry). +- By default, the value of `server.Options.Capabilities.MaximumMessageExpiryInterval` is set to 86400 (24 hours), in order to prevent exposing the broker to DOS attacks on hostile networks when using the out-of-the-box configuration (as an infinite expiry would allow an infinite number of retained/inflight messages to accumulate). If you are operating in a trusted environment, or you have capacity for a larger retention period, you may wish to override this (set to `0` for no expiry). ## Event Hooks A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools. diff --git a/clients.go b/clients.go index 6d5ff9a0..f5fafd74 100644 --- a/clients.go +++ b/clients.go @@ -327,10 +327,26 @@ func (cl *Client) ResendInflightMessages(force bool) error { } // ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session. -func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { +func (cl *Client) ClearInflights() { + for _, tk := range cl.State.Inflight.GetAll(false) { + if ok := cl.State.Inflight.Delete(tk.PacketID); ok { + cl.ops.hooks.OnQosDropped(cl, tk) + atomic.AddInt64(&cl.ops.info.Inflight, -1) + } + } +} + +// ClearExpiredInflights deletes any inflight messages which have expired. +func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 { deleted := []uint16{} for _, tk := range cl.State.Inflight.GetAll(false) { - if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now { + expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5] + + // If the maximum message expiry interval is set (greater than 0), and the message + // retention period exceeds the maximum expiry, the message will be forcibly removed. + enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry + + if expired || enforced { if ok := cl.State.Inflight.Delete(tk.PacketID); ok { cl.ops.hooks.OnQosDropped(cl, tk) atomic.AddInt64(&cl.ops.info.Inflight, -1) diff --git a/clients_test.go b/clients_test.go index e992b5dc..e77e3c50 100644 --- a/clients_test.go +++ b/clients_test.go @@ -302,19 +302,56 @@ func TestClientNextPacketIDOverflow(t *testing.T) { func TestClientClearInflights(t *testing.T) { cl, _, _ := newTestClient() + n := time.Now().Unix() + + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n}) + + require.Equal(t, 5, cl.State.Inflight.Len()) + cl.ClearInflights() + require.Equal(t, 0, cl.State.Inflight.Len()) +} + +func TestClientClearExpiredInflights(t *testing.T) { + cl, _, _ := newTestClient() n := time.Now().Unix() - cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) - cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2}) - cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds - cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit - cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n}) require.Equal(t, 5, cl.State.Inflight.Len()) - deleted := cl.ClearInflights(n, 4) + deleted := cl.ClearExpiredInflights(n, 4) require.Len(t, deleted, 3) require.ElementsMatch(t, []uint16{1, 2, 5}, deleted) require.Equal(t, 2, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3. + cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3 + cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit + require.Equal(t, 6, cl.State.Inflight.Len()) + + deleted = cl.ClearExpiredInflights(n, 4) + require.Len(t, deleted, 3) + require.ElementsMatch(t, []uint16{11, 12, 15}, deleted) + require.Equal(t, 3, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1}) + deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages + require.Len(t, deleted, 0) + require.Equal(t, 4, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1}) + deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages + require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5. + require.Len(t, deleted, 1) + require.Equal(t, 4, cl.State.Inflight.Len()) } func TestClientResendInflightMessages(t *testing.T) { diff --git a/server.go b/server.go index 8b7b46da..0086d4b5 100644 --- a/server.go +++ b/server.go @@ -400,7 +400,7 @@ func (s *Server) attachClient(cl *Client, listener string) error { s.hooks.OnDisconnect(cl, err, expire) if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { - cl.ClearInflights(math.MaxInt64, 0) + cl.ClearInflights() s.UnsubscribeClient(cl) s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] } @@ -478,7 +478,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { _ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] s.UnsubscribeClient(existing) - existing.ClearInflights(math.MaxInt64, 0) + existing.ClearInflights() atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred return false // [MQTT-3.2.2-3] } @@ -503,7 +503,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { // Clean the state of the existing client to prevent sequential take-overs // from increasing memory usage by inflights + subs * client-id. s.UnsubscribeClient(existing) - existing.ClearInflights(math.MaxInt64, 0) + existing.ClearInflights() s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote) @@ -1597,7 +1597,14 @@ func (s *Server) clearExpiredClients(dt int64) { // clearExpiredRetainedMessage deletes retained messages from topics if they have expired. func (s *Server) clearExpiredRetainedMessages(now int64) { for filter, pk := range s.Topics.Retained.GetAll() { - if (pk.Expiry > 0 && pk.Expiry < now) || pk.Created+s.Options.Capabilities.MaximumMessageExpiryInterval < now { + expired := pk.ProtocolVersion == 5 && pk.Expiry > 0 && pk.Expiry < now // [MQTT-3.3.2-5] + + // If the maximum message expiry interval is set (greater than 0), and the message + // retention period exceeds the maximum expiry, the message will be forcibly removed. + enforced := s.Options.Capabilities.MaximumMessageExpiryInterval > 0 && + now-pk.Created > s.Options.Capabilities.MaximumMessageExpiryInterval + + if expired || enforced { s.Topics.Retained.Delete(filter) s.hooks.OnRetainedExpired(filter) } @@ -1607,7 +1614,7 @@ func (s *Server) clearExpiredRetainedMessages(now int64) { // clearExpiredInflights deletes any inflight messages which have expired. func (s *Server) clearExpiredInflights(now int64) { for _, client := range s.Clients.GetAll() { - if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 { + if deleted := client.ClearExpiredInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 { for _, id := range deleted { s.hooks.OnQosDropped(client, packets.Packet{PacketID: id}) } diff --git a/server_test.go b/server_test.go index ec0dc3b6..604aaae6 100644 --- a/server_test.go +++ b/server_test.go @@ -3259,6 +3259,11 @@ func TestServerClearExpiredInflights(t *testing.T) { s.clearExpiredInflights(n) require.Len(t, cl.State.Inflight.GetAll(false), 2) require.Equal(t, int64(-3), s.Info.Inflight) + + s.Options.Capabilities.MaximumMessageExpiryInterval = 0 + cl.State.Inflight.Set(packets.Packet{PacketID: 8, Expiry: n - 8}) + s.clearExpiredInflights(n) + require.Len(t, cl.State.Inflight.GetAll(false), 3) } func TestServerClearExpiredRetained(t *testing.T) { @@ -3267,15 +3272,28 @@ func TestServerClearExpiredRetained(t *testing.T) { s.Options.Capabilities.MaximumMessageExpiryInterval = 4 n := time.Now().Unix() - s.Topics.Retained.Add("a/b/c", packets.Packet{Created: n, Expiry: n - 1}) - s.Topics.Retained.Add("d/e/f", packets.Packet{Created: n, Expiry: n - 2}) - s.Topics.Retained.Add("g/h/i", packets.Packet{Created: n - 3}) // within bounds - s.Topics.Retained.Add("j/k/l", packets.Packet{Created: n - 5}) // over max server expiry limit - s.Topics.Retained.Add("m/n/o", packets.Packet{Created: n}) + s.Topics.Retained.Add("a/b/c", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 1}) + s.Topics.Retained.Add("d/e/f", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 2}) + s.Topics.Retained.Add("g/h/i", packets.Packet{ProtocolVersion: 5, Created: n - 3}) // within bounds + s.Topics.Retained.Add("j/k/l", packets.Packet{ProtocolVersion: 5, Created: n - 5}) // over max server expiry limit + s.Topics.Retained.Add("m/n/o", packets.Packet{ProtocolVersion: 5, Created: n}) require.Len(t, s.Topics.Retained.GetAll(), 5) s.clearExpiredRetainedMessages(n) require.Len(t, s.Topics.Retained.GetAll(), 2) + + s.Topics.Retained.Add("p/q/r", packets.Packet{Created: n, Expiry: n - 1}) + s.Topics.Retained.Add("s/t/u", packets.Packet{Created: n, Expiry: n - 2}) // expiry is ineffective for v3. + s.Topics.Retained.Add("v/w/x", packets.Packet{Created: n - 3}) // within bounds for v3 + s.Topics.Retained.Add("y/z/1", packets.Packet{Created: n - 5}) // over max server expiry limit + require.Len(t, s.Topics.Retained.GetAll(), 6) + s.clearExpiredRetainedMessages(n) + require.Len(t, s.Topics.Retained.GetAll(), 5) + + s.Options.Capabilities.MaximumMessageExpiryInterval = 0 + s.Topics.Retained.Add("2/3/4", packets.Packet{Created: n - 8}) + s.clearExpiredRetainedMessages(n) + require.Len(t, s.Topics.Retained.GetAll(), 6) } func TestServerClearExpiredClients(t *testing.T) {