Skip to content

Commit

Permalink
Merge pull request #55 from lxzan/testing
Browse files Browse the repository at this point in the history
v1.6.13
  • Loading branch information
lxzan authored Oct 13, 2023
2 parents 2a1e621 + 8523a6e commit 8bbe5eb
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 72 deletions.
27 changes: 7 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,50 +100,37 @@ package main
import (
"github.com/lxzan/gws"
"net/http"
"time"
)

const (
PingInterval = 10 * time.Second
PingWait = 5 * time.Second
)

func main() {
upgrader := gws.NewUpgrader(&Handler{}, &gws.ServerOption{
ReadAsyncEnabled: true,
CompressEnabled: true,
CheckUtf8Enabled: true,
Recovery: gws.Recovery,
})
http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) {
socket, err := upgrader.Upgrade(writer, request)
if err != nil {
return
}
go func() {
// Blocking prevents the context from being GC.
socket.ReadLoop()
}()
})
http.ListenAndServe(":6666", nil)
http.ListenAndServe(":8000", nil)
}

type Handler struct{}

func (c *Handler) OnOpen(socket *gws.Conn) {
_ = socket.SetDeadline(time.Now().Add(PingInterval + PingWait))
type Handler struct {
gws.BuiltinEventHandler
}

func (c *Handler) OnClose(socket *gws.Conn, err error) {}

func (c *Handler) OnPing(socket *gws.Conn, payload []byte) {
_ = socket.SetDeadline(time.Now().Add(PingInterval + PingWait))
_ = socket.WritePong(nil)
_ = socket.WritePong(payload)
}

func (c *Handler) OnPong(socket *gws.Conn, payload []byte) {}

func (c *Handler) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close()
socket.WriteMessage(message.Opcode, message.Bytes())
_ = socket.WriteMessage(message.Opcode, message.Bytes())
}
```

Expand Down
4 changes: 2 additions & 2 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func BenchmarkStdDeCompress(b *testing.B) {
src := bytes.NewBuffer(nil)
for i := 0; i < b.N; i++ {
internal.BufferReset(src, buffer.Bytes())
_, _ = src.Write(internal.FlateTail)
_, _ = src.Write(flateTail)
resetter := fr.(flate.Resetter)
_ = resetter.Reset(src, nil)
io.CopyBuffer(io.Discard, fr, p)
Expand All @@ -169,7 +169,7 @@ func BenchmarkKlauspostDeCompress(b *testing.B) {
src := bytes.NewBuffer(nil)
for i := 0; i < b.N; i++ {
internal.BufferReset(src, buffer.Bytes())
_, _ = src.Write(internal.FlateTail)
_, _ = src.Write(flateTail)
resetter := fr.(klauspost.Resetter)
_ = resetter.Reset(src, nil)
fr.(io.WriterTo).WriteTo(io.Discard)
Expand Down
6 changes: 6 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ func NewClient(handler Event, option *ClientOption) (*Conn, *http.Response, erro
return nil, nil, err
}
if tlsEnabled {
if option.TlsConfig == nil {
option.TlsConfig = &tls.Config{}
}
if option.TlsConfig.ServerName == "" {
option.TlsConfig.ServerName = URL.Host
}
c.conn = tls.Client(c.conn, option.TlsConfig)
}

Expand Down
14 changes: 12 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,21 @@ func TestNewClientWSS(t *testing.T) {
}()

t.Run("", func(t *testing.T) {
_, _, err = NewClient(&BuiltinEventHandler{}, &ClientOption{
opts := &ClientOption{
Addr: "wss://" + addr,
TlsConfig: &tls.Config{InsecureSkipVerify: true},
})
}
_, _, err = NewClient(&BuiltinEventHandler{}, opts)
as.NoError(err)
as.Equal(addr, opts.TlsConfig.ServerName)
})

t.Run("", func(t *testing.T) {
opts := &ClientOption{
Addr: "wss://" + addr,
}
_, _, err = NewClient(&BuiltinEventHandler{}, opts)
as.Error(err)
})

t.Run("", func(t *testing.T) {
Expand Down
6 changes: 5 additions & 1 deletion compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"sync/atomic"
)

// FlateTail Add four bytes as specified in RFC
// Add final block to squelch unexpected EOF error from flate reader.
var flateTail = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

type compressors struct {
serial uint64
size uint64
Expand Down Expand Up @@ -109,7 +113,7 @@ func (c *decompressor) Decompress(src *bytes.Buffer) (*bytes.Buffer, int, error)
c.Lock()
defer c.Unlock()

_, _ = src.Write(internal.FlateTail)
_, _ = src.Write(flateTail)
c.reset(src)
if _, err := c.fr.(io.WriterTo).WriteTo(c.b); err != nil {
return nil, 0, err
Expand Down
6 changes: 5 additions & 1 deletion examples/echo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package main

import (
"github.com/lxzan/gws"
"log"
"net/http"
)

func main() {
upgrader := gws.NewUpgrader(&Handler{}, &gws.ServerOption{
CompressEnabled: true,
CheckUtf8Enabled: true,
Recovery: gws.Recovery,
})
http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) {
socket, err := upgrader.Upgrade(writer, request)
Expand All @@ -19,7 +21,9 @@ func main() {
socket.ReadLoop()
}()
})
http.ListenAndServe(":8000", nil)
log.Panic(
http.ListenAndServe(":8000", nil),
)
}

type Handler struct {
Expand Down
5 changes: 3 additions & 2 deletions init.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gws
import "github.com/lxzan/gws/internal"

var (
myPadding = frameHeader{} // 帧头填充物
binaryPool = internal.NewBufferPool() // 静态缓冲池
myPadding = frameHeader{} // 帧头填充物
binaryPool = internal.NewBufferPool() // 缓冲池
defaultLogger = new(stdLogger) // 默认日志工具
)
18 changes: 3 additions & 15 deletions internal/others.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package internal

import (
"io"
"math"
"net"
)
Expand All @@ -23,10 +22,6 @@ var (
SecWebSocketProtocol = Pair{"Sec-WebSocket-Protocol", ""}
)

// FlateTail Add four bytes as specified in RFC
// Add final block to squelch unexpected EOF error from flate reader.
var FlateTail = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

const MagicNumber = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

const (
Expand All @@ -35,13 +30,6 @@ const (
ThresholdV3 = math.MaxUint64
)

type (
ReadLener interface {
io.Reader
Len() int
}

NetConn interface {
NetConn() net.Conn
}
)
type NetConn interface {
NetConn() net.Conn
}
31 changes: 31 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ type (
// 是否检查文本utf8编码, 关闭性能会好点
// Whether to check the text utf8 encoding, turn off the performance will be better
CheckUtf8Enabled bool

// 消息回调(OnMessage)的恢复程序
// Message callback (OnMessage) recovery program
Recovery func(logger Logger)

// 日志工具
// Logging tools
Logger Logger
}

ServerOption struct {
Expand All @@ -93,6 +101,11 @@ type (
CompressThreshold int
CompressorNum int
CheckUtf8Enabled bool
Logger Logger
Recovery func(logger Logger)

// TLS设置
TlsConfig *tls.Config

// 握手超时时间
HandshakeTimeout time.Duration
Expand Down Expand Up @@ -165,6 +178,12 @@ func initServerOption(c *ServerOption) *ServerOption {
if c.HandshakeTimeout <= 0 {
c.HandshakeTimeout = defaultHandshakeTimeout
}
if c.Logger == nil {
c.Logger = defaultLogger
}
if c.Recovery == nil {
c.Recovery = func(logger Logger) {}
}
c.CompressorNum = internal.ToBinaryNumber(c.CompressorNum)
c.deleteProtectedHeaders()

Expand All @@ -181,6 +200,8 @@ func initServerOption(c *ServerOption) *ServerOption {
CompressThreshold: c.CompressThreshold,
CheckUtf8Enabled: c.CheckUtf8Enabled,
CompressorNum: c.CompressorNum,
Recovery: c.Recovery,
Logger: c.Logger,
}
if c.config.CompressEnabled {
c.config.compressors = new(compressors).initialize(c.CompressorNum, c.config.CompressLevel)
Expand All @@ -207,6 +228,8 @@ type ClientOption struct {
CompressLevel int
CompressThreshold int
CheckUtf8Enabled bool
Logger Logger
Recovery func(logger Logger)

// 连接地址, 例如 wss://example.com/connect
// server address, eg: wss://example.com/connect
Expand Down Expand Up @@ -274,6 +297,12 @@ func initClientOption(c *ClientOption) *ClientOption {
if c.NewSessionStorage == nil {
c.NewSessionStorage = func() SessionStorage { return new(sliceMap) }
}
if c.Logger == nil {
c.Logger = defaultLogger
}
if c.Recovery == nil {
c.Recovery = func(logger Logger) {}
}
return c
}

Expand All @@ -289,6 +318,8 @@ func (c *ClientOption) getConfig() *Config {
CompressLevel: c.CompressLevel,
CompressThreshold: c.CompressThreshold,
CheckUtf8Enabled: c.CheckUtf8Enabled,
Recovery: c.Recovery,
Logger: c.Logger,
CompressorNum: 1,
}
if config.CompressEnabled {
Expand Down
4 changes: 4 additions & 0 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func validateServerOption(as *assert.Assertions, u *Upgrader) {
as.Equal(config.WriteBufferSize, option.WriteBufferSize)
as.Equal(config.CompressorNum, option.CompressorNum)
as.NotNil(config.readerPool)
as.NotNil(config.Recovery)
as.Equal(config.Logger, defaultLogger)

_, ok := u.option.NewSessionStorage().(*sliceMap)
as.True(ok)
Expand All @@ -42,6 +44,8 @@ func validateClientOption(as *assert.Assertions, option *ClientOption) {
as.Equal(config.ReadBufferSize, option.ReadBufferSize)
as.Equal(config.WriteBufferSize, option.WriteBufferSize)
as.Nil(config.readerPool)
as.NotNil(config.Recovery)
as.Equal(config.Logger, defaultLogger)

_, ok := option.NewSessionStorage().(*sliceMap)
as.True(ok)
Expand Down
30 changes: 17 additions & 13 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ func (c *Conn) readMessage() error {
}

var fin = c.fh.GetFIN()
var buf, index = binaryPool.Get(contentLength)
var buf, index = binaryPool.Get(contentLength + len(flateTail))
var p = buf.Bytes()[:contentLength]
var closer = Message{Data: buf, index: index}
defer closer.Close()

if err := internal.ReadN(c.br, p); err != nil {
return err
}
Expand All @@ -108,6 +111,9 @@ func (c *Conn) readMessage() error {

if fin && opcode != OpcodeContinuation {
*(*[]byte)(unsafe.Pointer(buf)) = p
if !compressed {
closer.Data, closer.index = nil, 0
}
return c.emitMessage(&Message{index: index, Opcode: opcode, Data: buf, compressed: compressed})
}

Expand All @@ -121,11 +127,8 @@ func (c *Conn) readMessage() error {
if !c.continuationFrame.initialized {
return internal.CloseProtocolError
}
if err := internal.WriteN(c.continuationFrame.buffer, p); err != nil {
return err
} else {
binaryPool.Put(buf, index)
}

c.continuationFrame.buffer.Write(p)
if c.continuationFrame.buffer.Len() > c.config.ReadMaxPayloadSize {
return internal.CloseMessageTooLarge
}
Expand All @@ -138,23 +141,24 @@ func (c *Conn) readMessage() error {
return c.emitMessage(msg)
}

func (c *Conn) dispatch(msg *Message) error {
defer c.config.Recovery(c.config.Logger)
c.handler.OnMessage(c, msg)
return nil
}

func (c *Conn) emitMessage(msg *Message) (err error) {
if msg.compressed {
data, index := msg.Data, msg.index
msg.Data, msg.index, err = c.decompressor.Decompress(msg.Data)
binaryPool.Put(data, index)
if err != nil {
return internal.NewError(internal.CloseInternalServerErr, err)
}
}
if !c.isTextValid(msg.Opcode, msg.Bytes()) {
return internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding)
}

if c.config.ReadAsyncEnabled {
c.readQueue.Go(func() { c.handler.OnMessage(c, msg) })
} else {
c.handler.OnMessage(c, msg)
return c.readQueue.Go(msg, c.dispatch)
}
return nil
return c.dispatch(msg)
}
Loading

0 comments on commit 8bbe5eb

Please sign in to comment.