Skip to content

Commit

Permalink
refactor: change server.Run() into blocking function
Browse files Browse the repository at this point in the history
  • Loading branch information
DrmagicE committed Jan 28, 2021
1 parent c4d4019 commit 4be42cd
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 40 deletions.
10 changes: 6 additions & 4 deletions cmd/gmqttd/command/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ func installSignal(srv server.Server) {
srv.ApplyConfig(c)
logger.Info("gmqtt reloaded")
case <-stopSignalCh:
srv.Stop(context.Background())
return
err := srv.Stop(context.Background())
if err != nil {
fmt.Fprint(os.Stderr, err.Error())
}
}
}

Expand Down Expand Up @@ -129,13 +131,13 @@ func NewStartCmd() *cobra.Command {
os.Exit(1)
return
}
go installSignal(s)
err = s.Run()
if err != nil {
fmt.Println(err)
fmt.Fprint(os.Stderr, err.Error())
os.Exit(1)
return
}
installSignal(s)
},
}
return cmd
Expand Down
10 changes: 6 additions & 4 deletions examples/hook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,14 @@ func main() {
server.WithLogger(l),
server.WithConfig(config.DefaultConfig()),
)
go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh
s.Stop(context.Background())
}()
err = s.Run()
if err != nil {
panic(err)
}
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh
s.Stop(context.Background())
}
24 changes: 13 additions & 11 deletions examples/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ func main() {
err = srv.Init(server.WithHook(server.Hooks{
OnConnected: func(ctx context.Context, client server.Client) {
// add subscription for a client when it is connected
subService = srv.SubscriptionService()
subService.Subscribe(client.ClientOptions().ClientID, &gmqtt.Subscription{
TopicFilter: "topic",
QoS: packets.Qos0,
})
},
}))
subService = srv.SubscriptionService()

if err != nil {
fmt.Println(err.Error())
Expand All @@ -67,12 +67,6 @@ func main() {
})

// publish service

err = srv.Run()
if err != nil {
panic(err)
}

go func() {
for {
<-time.NewTimer(5 * time.Second).C
Expand All @@ -93,8 +87,16 @@ func main() {
}

}()
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh
srv.Stop(context.Background())

go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh
srv.Stop(context.Background())
}()
err = srv.Run()
if err != nil {
panic(err)
}

}
23 changes: 14 additions & 9 deletions server/api_registrar.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,25 @@ func buildHTTPServer(endpoint *config.Endpoint) (*httpServer, error) {
}, nil
}

func (srv *server) exit() {
select {
case <-srv.exitChan:
default:
close(srv.exitChan)
}
}

func (srv *server) serveAPIServer() {
defer func() {
srv.wg.Done()
srv.Stop(context.Background())
}()
var err error
errChan := make(chan error, 1)
defer func() {
srv.wg.Done()
if err != nil {
zaplog.Error("serveAPIServer error", zap.Error(err))
srv.setError(err)
}
}()
errChan := make(chan error, 1)
defer func() {
for _, v := range srv.apiRegistrar.gRPCServers {
v.shutdown()
}
Expand Down Expand Up @@ -242,10 +250,7 @@ func (srv *server) serveAPIServer() {
select {
case <-srv.exitChan:
return
case err := <-errChan:
if err != nil {
zaplog.Error("gRPC server stop error", zap.Error(err))
}
case err = <-errChan:
return
}

Expand Down
29 changes: 17 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (c *clientService) TerminateSession(clientID string) {
}

// server represents a mqtt server instance.
// Create a server by using NewServer()
// Create a server by using New()
type server struct {
wg sync.WaitGroup
initOnce sync.Once
Expand All @@ -162,6 +162,8 @@ type server struct {
willMessage map[string]*willMsg
tcpListener []net.Listener //tcp listeners
websocketServer []*WsServer //websocket serverStop
errOnce sync.Once
err error
exitChan chan struct{}

retainedDB retained.Store
Expand Down Expand Up @@ -791,7 +793,6 @@ func defaultServer() *server {

// New returns a gmqtt server instance with the given options
func New(opts ...Options) *server {
// statistics
srv := defaultServer()
for _, fn := range opts {
fn(srv)
Expand Down Expand Up @@ -1006,8 +1007,8 @@ func (srv *server) serveWebSocket(ws *WsServer) {
} else {
err = ws.Server.ListenAndServe()
}
if err != http.ErrServerClosed {
panic(err.Error())
if err != nil && err != http.ErrServerClosed {
srv.setError(fmt.Errorf("serveWebSocket error: %s", err.Error()))
}
}

Expand Down Expand Up @@ -1286,7 +1287,14 @@ func (srv *server) wsHandler() http.HandlerFunc {
}
}

// Run starts the mqtt server. This method is non-blocking
func (srv *server) setError(err error) {
srv.errOnce.Do(func() {
srv.err = err
srv.exit()
})
}

// Run starts the mqtt server.
func (srv *server) Run() (err error) {
err = srv.Init()
if err != nil {
Expand Down Expand Up @@ -1315,7 +1323,8 @@ func (srv *server) Run() (err error) {
server.Server.Handler = mux
go srv.serveWebSocket(server)
}
return nil
srv.wg.Wait()
return srv.err
}

// Stop gracefully stops the mqtt server by the following steps:
Expand All @@ -1329,13 +1338,9 @@ func (srv *server) Stop(ctx context.Context) error {
zaplog.Info("server stopped")
//zaplog.Sync()
}()
select {
case <-srv.exitChan:
return nil
default:
close(srv.exitChan)
}
srv.exit()
srv.wg.Wait()

for _, l := range srv.tcpListener {
l.Close()
}
Expand Down

0 comments on commit 4be42cd

Please sign in to comment.