From 624dde0986e28442a1406145de637c835190b10d Mon Sep 17 00:00:00 2001 From: werben Date: Wed, 13 Dec 2023 07:45:56 +0800 Subject: [PATCH] Handle expired clients in server.loadClients(). (#341) * Handle expired clients in server.loadClients(). * No need to call s.Clients.Delete(). --------- Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com> --- server.go | 10 +++++++++- server_test.go | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 0086d4b5..e94950f8 100644 --- a/server.go +++ b/server.go @@ -1553,7 +1553,15 @@ func (s *Server) loadClients(v []storage.Client) { MaximumPacketSize: c.Properties.MaximumPacketSize, } cl.Properties.Will = Will(c.Will) - s.Clients.Add(cl) + + expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) + s.hooks.OnDisconnect(cl, packets.ErrServerShuttingDown, expire) + if expire { + cl.ClearInflights() + s.UnsubscribeClient(cl) + } else { + s.Clients.Add(cl) + } } } diff --git a/server_test.go b/server_test.go index 604aaae6..cc2175f3 100644 --- a/server_test.go +++ b/server_test.go @@ -3128,15 +3128,50 @@ func TestServerLoadClients(t *testing.T) { {ID: "mochi"}, {ID: "zen"}, {ID: "mochi-co"}, + {ID: "v3-clean", ProtocolVersion: 4, Clean: true}, + {ID: "v3-not-clean", ProtocolVersion: 4, Clean: false}, + { + ID: "v5-clean", + ProtocolVersion: 5, + Clean: true, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 10, + }, + }, + { + ID: "v5-expire-interval-0", + ProtocolVersion: 5, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 0, + }, + }, + { + ID: "v5-expire-interval-not-0", + ProtocolVersion: 5, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 10, + }, + }, } s := newServer() require.Equal(t, 0, s.Clients.Len()) s.loadClients(v) - require.Equal(t, 3, s.Clients.Len()) + require.Equal(t, 6, s.Clients.Len()) cl, ok := s.Clients.Get("mochi") require.True(t, ok) require.Equal(t, "mochi", cl.ID) + + _, ok = s.Clients.Get("v3-clean") + require.False(t, ok) + _, ok = s.Clients.Get("v3-not-clean") + require.True(t, ok) + _, ok = s.Clients.Get("v5-clean") + require.True(t, ok) + _, ok = s.Clients.Get("v5-expire-interval-0") + require.False(t, ok) + _, ok = s.Clients.Get("v5-expire-interval-not-0") + require.True(t, ok) } func TestServerLoadSubscriptions(t *testing.T) {