-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support graceful shutdown #148
Open
jefftt
wants to merge
3
commits into
fiorix:master
Choose a base branch
from
jefftt:support-graceful-shutdown
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,15 @@ package diam | |
import ( | ||
"bufio" | ||
"crypto/tls" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"log" | ||
"math/rand" | ||
"net" | ||
"runtime" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"golang.org/x/net/context" | ||
|
@@ -87,6 +90,8 @@ type conn struct { | |
tlsState *tls.ConnectionState // or nil when not using TLS | ||
writer *response // the diam.Conn exposed to handlers | ||
|
||
curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) | ||
|
||
mu sync.Mutex // guards the following | ||
closeNotifyc chan struct{} | ||
clientGone bool | ||
|
@@ -137,6 +142,45 @@ func (c *conn) notifyClientGone() { | |
} | ||
} | ||
|
||
// A ConnState represents the state of a client connection to a server. | ||
type ConnState int | ||
|
||
const ( | ||
// StateNew represents a new connection that is expected to | ||
// send a request immediately. Connections begin at this | ||
// state and then transition to either StateActive or | ||
// StateClosed. | ||
StateNew ConnState = iota | ||
|
||
// StateActive represents a connection that has read 1 or more | ||
// bytes of a request. | ||
// After the request is handled, the state | ||
// transitions to StateClosed, or StateIdle. | ||
StateActive | ||
|
||
// StateIdle represents a connection that has finished | ||
// handling a request and is in the keep-alive state, waiting | ||
// for a new request. Connections transition from StateIdle | ||
// to either StateActive or StateClosed. | ||
StateIdle | ||
|
||
// StateClosed represents a closed connection. | ||
// This is a terminal state. Hijacked connections do not | ||
// transition to StateClosed. | ||
StateClosed | ||
) | ||
|
||
var stateName = map[ConnState]string{ | ||
StateNew: "new", | ||
StateActive: "active", | ||
StateIdle: "idle", | ||
StateClosed: "closed", | ||
} | ||
|
||
func (c ConnState) String() string { | ||
return stateName[c] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. name, ok := stateName[c] |
||
} | ||
|
||
// Create new connection from rwc. | ||
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { | ||
msc, isMulti := rwc.(MultistreamConn) | ||
|
@@ -157,6 +201,26 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { | |
return c, nil | ||
} | ||
|
||
func (c *conn) setState(state ConnState) { | ||
srv := c.server | ||
switch state { | ||
case StateNew: | ||
srv.trackConn(c, true) | ||
case StateClosed: | ||
srv.trackConn(c, false) | ||
} | ||
if state > 0xff || state < 0 { | ||
panic("internal error") | ||
} | ||
packedState := uint64(time.Now().Unix()<<8) | uint64(state) | ||
atomic.StoreUint64(&c.curState.atomic, packedState) | ||
} | ||
|
||
func (c *conn) getState() (state ConnState, unixSec int64) { | ||
packedState := atomic.LoadUint64(&c.curState.atomic) | ||
return ConnState(packedState & 0xff), int64(packedState >> 8) | ||
} | ||
|
||
// Read next message from connection. | ||
func (c *conn) readMessage() (m *Message, err error) { | ||
if c.server.ReadTimeout > 0 { | ||
|
@@ -185,6 +249,7 @@ func (c *conn) serve() { | |
c.rwc.RemoteAddr().String(), err, buf) | ||
} | ||
c.rwc.Close() | ||
c.setState(StateClosed) | ||
}() | ||
if tlsConn, ok := c.rwc.(*tls.Conn); ok { | ||
if err := tlsConn.Handshake(); err != nil { | ||
|
@@ -195,8 +260,10 @@ func (c *conn) serve() { | |
} | ||
for { | ||
m, err := c.readMessage() | ||
c.setState(StateActive) | ||
if err != nil { | ||
c.rwc.Close() | ||
c.setState(StateClosed) | ||
// Report errors to the channel, except EOF. | ||
if err != io.EOF && err != io.ErrUnexpectedEOF { | ||
h := c.server.Handler | ||
|
@@ -211,6 +278,7 @@ func (c *conn) serve() { | |
} | ||
// Handle messages in this goroutine. | ||
serverHandler{c.server}.ServeDIAM(c.writer, m) | ||
c.setState(StateIdle) | ||
} | ||
} | ||
|
||
|
@@ -223,6 +291,12 @@ func (c *conn) dictionary() *dict.Parser { | |
return c.server.Dict | ||
} | ||
|
||
type atomicBool int32 | ||
|
||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } | ||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } | ||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } | ||
|
||
// A response represents the server side of a diameter response. | ||
// It implements the Conn and CloseNotifier interfaces. | ||
type response struct { | ||
|
@@ -548,6 +622,10 @@ func ErrorReports() <-chan *ErrorReport { | |
return DefaultServeMux.ErrorReports() | ||
} | ||
|
||
// ErrServerClosed is returned by the Server's Serve, ListenAndServe, | ||
// methods after a call to Shutdown or Close. | ||
var ErrServerClosed = errors.New("diameter: Server closed") | ||
|
||
// Serve accepts incoming diameter connections on the listener l, | ||
// creating a new service goroutine for each. The service goroutines | ||
// read messages and then call handler to reply to them. | ||
|
@@ -567,6 +645,11 @@ type Server struct { | |
WriteTimeout time.Duration // maximum duration before timing out write of the response | ||
TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS | ||
LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to | ||
|
||
inShutdown atomicBool // true when when server is in shutdown | ||
mu sync.Mutex | ||
listeners map[*net.Listener]struct{} | ||
activeConn map[*conn]struct{} | ||
} | ||
|
||
// serverHandler delegates to either the server's Handler or DefaultServeMux. | ||
|
@@ -607,11 +690,22 @@ func (srv *Server) ListenAndServe() error { | |
// new service goroutine for each. The service goroutines read requests and | ||
// then call srv.Handler to reply to them. | ||
func (srv *Server) Serve(l net.Listener) error { | ||
l = &onceCloseListener{Listener: l} | ||
defer l.Close() | ||
|
||
if !srv.trackListener(&l, true) { | ||
return ErrServerClosed | ||
} | ||
defer srv.trackListener(&l, false) | ||
|
||
var tempDelay time.Duration // how long to sleep on accept failure | ||
for { | ||
rw, e := l.Accept() | ||
if e != nil { | ||
if srv.shuttingDown() { | ||
return ErrServerClosed | ||
} | ||
|
||
if ne, ok := e.(net.Error); ok && ne.Temporary() { | ||
if tempDelay == 0 { | ||
tempDelay = 5 * time.Millisecond | ||
|
@@ -640,11 +734,150 @@ func (srv *Server) Serve(l net.Listener) error { | |
log.Printf("srv.newConn error: %v", err) | ||
continue | ||
} else { | ||
c.setState(StateNew) | ||
go c.serve() | ||
} | ||
} | ||
} | ||
|
||
// shutdownPollIntervalMax is the max polling interval when checking | ||
// quiescence during Server.Shutdown. Polling starts with a small | ||
// interval and backs off to the max. | ||
// Ideally we could find a solution that doesn't involve polling, | ||
// but which also doesn't have a high runtime cost (and doesn't | ||
// involve any contentious mutexes), but that is left as an | ||
// exercise for the reader. | ||
const shutdownPollIntervalMax = 500 * time.Millisecond | ||
|
||
// Shutdown gracefully shuts down the server without interrupting any | ||
// active connections. Shutdown works by first closing all open | ||
// listeners, then closing all idle connections, and then waiting | ||
// indefinitely for connections to return to idle and then shut down. | ||
// | ||
// When Shutdown is called, Serve, ListenAndServe, and | ||
// ListenAndServeTLS immediately return ErrServerClosed. Make sure the | ||
// program doesn't exit and waits instead for Shutdown to return. | ||
// | ||
// Once Shutdown has been called on a server, it may not be reused; | ||
// future calls to methods such as Serve will return ErrServerClosed. | ||
func (srv *Server) Shutdown() error { | ||
srv.inShutdown.setTrue() | ||
|
||
srv.mu.Lock() | ||
lnerr := srv.closeListenersLocked() | ||
srv.mu.Unlock() | ||
|
||
pollIntervalBase := time.Millisecond | ||
nextPollInterval := func() time.Duration { | ||
// Add 10% jitter. | ||
interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) | ||
// Double and clamp for next time. | ||
pollIntervalBase *= 2 | ||
if pollIntervalBase > shutdownPollIntervalMax { | ||
pollIntervalBase = shutdownPollIntervalMax | ||
} | ||
return interval | ||
} | ||
|
||
timer := time.NewTimer(nextPollInterval()) | ||
defer timer.Stop() | ||
for { | ||
if srv.closeIdleConns() && srv.numListeners() == 0 { | ||
return lnerr | ||
} | ||
select { | ||
case <-timer.C: | ||
timer.Reset(nextPollInterval()) | ||
} | ||
} | ||
} | ||
|
||
func (s *Server) numListeners() int { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
return len(s.listeners) | ||
} | ||
|
||
// closeIdleConns closes all idle connections and reports whether the | ||
// server is quiescent. | ||
func (s *Server) closeIdleConns() bool { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
quiescent := true | ||
for c := range s.activeConn { | ||
st, unixSec := c.getState() | ||
// treat StateNew connections as if | ||
// they're idle if we haven't read the first request's | ||
// header in over 5 seconds. | ||
if st == StateNew && unixSec < time.Now().Unix()-5 { | ||
st = StateIdle | ||
} | ||
if st != StateIdle || unixSec == 0 { | ||
// Assume unixSec == 0 means it's a very new | ||
// connection, without state set yet. | ||
quiescent = false | ||
continue | ||
} | ||
c.rwc.Close() | ||
delete(s.activeConn, c) | ||
} | ||
return quiescent | ||
} | ||
|
||
// trackListener adds or removes a net.Listener to the set of tracked | ||
// listeners. | ||
// | ||
// We store a pointer to interface in the map set, in case the | ||
// net.Listener is not comparable. This is safe because we only call | ||
// trackListener via Serve and can track+defer untrack the same | ||
// pointer to local variable there. We never need to compare a | ||
// Listener from another caller. | ||
// | ||
// It reports whether the server is still up (not Shutdown or Closed). | ||
func (s *Server) trackListener(ln *net.Listener, add bool) bool { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
if s.listeners == nil { | ||
s.listeners = make(map[*net.Listener]struct{}) | ||
} | ||
if add { | ||
if s.shuttingDown() { | ||
return false | ||
} | ||
s.listeners[ln] = struct{}{} | ||
} else { | ||
delete(s.listeners, ln) | ||
} | ||
return true | ||
} | ||
|
||
func (s *Server) trackConn(c *conn, add bool) { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
if s.activeConn == nil { | ||
s.activeConn = make(map[*conn]struct{}) | ||
} | ||
if add { | ||
s.activeConn[c] = struct{}{} | ||
} else { | ||
delete(s.activeConn, c) | ||
} | ||
} | ||
|
||
func (s *Server) shuttingDown() bool { | ||
return s.inShutdown.isSet() | ||
} | ||
|
||
func (s *Server) closeListenersLocked() error { | ||
var err error | ||
for ln := range s.listeners { | ||
if cerr := (*ln).Close(); cerr != nil && err == nil { | ||
err = cerr | ||
} | ||
} | ||
return err | ||
} | ||
|
||
// ListenAndServeNetwork listens on the network & addr | ||
// and then calls Serve with handler to handle requests | ||
// on incoming connections. | ||
|
@@ -729,3 +962,18 @@ func ListenAndServeNetworkTLS(network, addr string, certFile string, keyFile str | |
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler, dp *dict.Parser) error { | ||
return ListenAndServeNetworkTLS("tcp", addr, certFile, keyFile, handler, dp) | ||
} | ||
|
||
// onceCloseListener wraps a net.Listener, protecting it from | ||
// multiple Close calls. | ||
type onceCloseListener struct { | ||
net.Listener | ||
once sync.Once | ||
closeErr error | ||
} | ||
|
||
func (oc *onceCloseListener) Close() error { | ||
oc.once.Do(oc.close) | ||
return oc.closeErr | ||
} | ||
|
||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not int8 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
taken from https://github.com/golang/go/blob/master/src/net/http/server.go#L2823
looks like it can be an int8 afaict, can update