Skip to content

Commit

Permalink
exception recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
lxzan committed Oct 12, 2023
1 parent 919ea14 commit afe21aa
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/echo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ func main() {
upgrader := gws.NewUpgrader(&Handler{}, &gws.ServerOption{
CompressEnabled: true,
CheckUtf8Enabled: true,
Caller: gws.Recovery(gws.StdLogger),
})
http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) {
socket, err := upgrader.Upgrade(writer, request)
Expand Down
2 changes: 2 additions & 0 deletions init.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ import "github.com/lxzan/gws/internal"
var (
myPadding = frameHeader{} // 帧头填充物
binaryPool = internal.NewBufferPool() // 缓冲池

StdLogger = new(stdLogger) // 标准日志输出
)
14 changes: 14 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ type (
// 是否检查文本utf8编码, 关闭性能会好点
// Whether to check the text utf8 encoding, turn off the performance will be better
CheckUtf8Enabled bool

// OnMessage调用器, 可用于异常恢复
// OnMessage caller, can be used for exception recovery
Caller Caller
}

ServerOption struct {
Expand All @@ -93,6 +97,7 @@ type (
CompressThreshold int
CompressorNum int
CheckUtf8Enabled bool
Caller Caller

// TLS设置
TlsConfig *tls.Config
Expand Down Expand Up @@ -168,6 +173,9 @@ func initServerOption(c *ServerOption) *ServerOption {
if c.HandshakeTimeout <= 0 {
c.HandshakeTimeout = defaultHandshakeTimeout
}
if c.Caller == nil {
c.Caller = func(f func()) { f() }
}
c.CompressorNum = internal.ToBinaryNumber(c.CompressorNum)
c.deleteProtectedHeaders()

Expand All @@ -184,6 +192,7 @@ func initServerOption(c *ServerOption) *ServerOption {
CompressThreshold: c.CompressThreshold,
CheckUtf8Enabled: c.CheckUtf8Enabled,
CompressorNum: c.CompressorNum,
Caller: c.Caller,
}
if c.config.CompressEnabled {
c.config.compressors = new(compressors).initialize(c.CompressorNum, c.config.CompressLevel)
Expand All @@ -210,6 +219,7 @@ type ClientOption struct {
CompressLevel int
CompressThreshold int
CheckUtf8Enabled bool
Caller Caller

// 连接地址, 例如 wss://example.com/connect
// server address, eg: wss://example.com/connect
Expand Down Expand Up @@ -277,6 +287,9 @@ func initClientOption(c *ClientOption) *ClientOption {
if c.NewSessionStorage == nil {
c.NewSessionStorage = func() SessionStorage { return new(sliceMap) }
}
if c.Caller == nil {
c.Caller = func(f func()) { f() }
}
return c
}

Expand All @@ -292,6 +305,7 @@ func (c *ClientOption) getConfig() *Config {
CompressLevel: c.CompressLevel,
CompressThreshold: c.CompressThreshold,
CheckUtf8Enabled: c.CheckUtf8Enabled,
Caller: c.Caller,
CompressorNum: 1,
}
if config.CompressEnabled {
Expand Down
2 changes: 2 additions & 0 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func validateServerOption(as *assert.Assertions, u *Upgrader) {
as.Equal(config.WriteBufferSize, option.WriteBufferSize)
as.Equal(config.CompressorNum, option.CompressorNum)
as.NotNil(config.readerPool)
as.NotNil(config.Caller)

_, ok := u.option.NewSessionStorage().(*sliceMap)
as.True(ok)
Expand All @@ -42,6 +43,7 @@ func validateClientOption(as *assert.Assertions, option *ClientOption) {
as.Equal(config.ReadBufferSize, option.ReadBufferSize)
as.Equal(config.WriteBufferSize, option.WriteBufferSize)
as.Nil(config.readerPool)
as.NotNil(config.Caller)

_, ok := option.NewSessionStorage().(*sliceMap)
as.True(ok)
Expand Down
10 changes: 8 additions & 2 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,15 @@ func (c *Conn) emitMessage(msg *Message) (err error) {
}

if c.config.ReadAsyncEnabled {
c.readQueue.Go(func() { c.handler.OnMessage(c, msg) })
c.readQueue.Go(func() {
c.config.Caller(func() {
c.handler.OnMessage(c, msg)
})
})
} else {
c.handler.OnMessage(c, msg)
c.config.Caller(func() {
c.handler.OnMessage(c, msg)
})
}
return nil
}
35 changes: 35 additions & 0 deletions recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package gws

import (
"log"
"runtime"
"unsafe"
)

type Logger interface {
Error(v ...any)
}

type stdLogger struct{}

func (c *stdLogger) Error(v ...any) {
log.Println(v...)
}

type Caller func(f func())

func Recovery(logger Logger) Caller {
return func(f func()) {
defer func() {
if e := recover(); e != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
msg := *(*string)(unsafe.Pointer(&buf))
logger.Error(e, msg)
}
}()

f()
}
}
17 changes: 17 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,20 @@ func (b broadcastHandler) OnMessage(socket *Conn, message *Message) {
defer message.Close()
b.wg.Done()
}

func TestRecovery(t *testing.T) {
var as = assert.New(t)
var serverHandler = new(webSocketMocker)
var clientHandler = new(webSocketMocker)
var serverOption = &ServerOption{Caller: Recovery(StdLogger)}
var clientOption = &ClientOption{}
serverHandler.onMessage = func(socket *Conn, message *Message) {
var m map[string]uint8
m[""] = 1
}
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
go server.ReadLoop()
go client.ReadLoop()
as.NoError(client.WriteString("hi"))
time.Sleep(100 * time.Millisecond)
}

0 comments on commit afe21aa

Please sign in to comment.