Skip to content

Commit

Permalink
Update dialer implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 15, 2024
1 parent 897565c commit fb005a6
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 181 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.20
require (
github.com/gofrs/uuid/v5 v5.3.0
github.com/sagernet/quic-go v0.48.1-beta.1
github.com/sagernet/sing v0.5.0
github.com/sagernet/sing v0.6.0-alpha.12
golang.org/x/crypto v0.28.0
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/sagernet/quic-go v0.48.1-beta.1 h1:ElPaV5yzlXIKZpqFMAcUGax6vddi3zt4AEpT94Z0vwo=
github.com/sagernet/quic-go v0.48.1-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k=
github.com/sagernet/sing v0.5.0 h1:soo2wVwLcieKWWKIksFNK6CCAojUgAppqQVwyRYGkEM=
github.com/sagernet/sing v0.5.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.12 h1:RqTvSLcgnpcAVz+jzW9UE4IdqUMIxMJwZKRt+d6XDnU=
github.com/sagernet/sing v0.6.0-alpha.12/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
154 changes: 90 additions & 64 deletions hysteria/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-quic"
hyCC "github.com/sagernet/sing-quic/hysteria/congestion"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
Expand All @@ -22,6 +23,29 @@ import (
aTLS "github.com/sagernet/sing/common/tls"
)

var (
_ N.Dialer = (*Client)(nil)
_ N.PayloadDialer = (*Client)(nil)
)

type Client struct {
ctx context.Context
dialer N.Dialer
logger logger.Logger
brutalDebug bool
serverAddr M.Socksaddr
sendBPS uint64
receiveBPS uint64
xplusPassword string
password string
tlsConfig aTLS.Config
quicConfig *quic.Config
udpDisabled bool

connAccess sync.RWMutex
conn *clientQUICConnection
}

type ClientOptions struct {
Context context.Context
Dialer N.Dialer
Expand All @@ -42,24 +66,6 @@ type ClientOptions struct {
DisableMTUDiscovery bool
}

type Client struct {
ctx context.Context
dialer N.Dialer
logger logger.Logger
brutalDebug bool
serverAddr M.Socksaddr
sendBPS uint64
receiveBPS uint64
xplusPassword string
password string
tlsConfig aTLS.Config
quicConfig *quic.Config
udpDisabled bool

connAccess sync.RWMutex
conn *clientQUICConnection
}

func NewClient(options ClientOptions) (*Client, error) {
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
Expand Down Expand Up @@ -182,19 +188,72 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
return conn, nil
}

func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
conn, err := c.offer(ctx)
if err != nil {
return nil, err
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
return c.DialPayloadContext(ctx, network, destination, nil)
case N.NetworkUDP:
packetConn, err := c.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(packetConn, destination), nil
default:
return nil, E.Cause(N.ErrUnknownNetwork, network)
}
stream, err := conn.quicConn.OpenStream()
if err != nil {
return nil, err
}

func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
conn, err := c.offer(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
stream, err := conn.quicConn.OpenStreamSync(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
buffer := WriteClientRequest(ClientRequest{
UDP: false,
Host: destination.AddrString(),
Port: destination.Port,
}, payloads)
_, err = stream.Write(buffer.Bytes())
buffer.Release()
if err != nil {
return nil, baderror.WrapQUIC(err)
}
response, err := ReadServerResponse(stream)
if err != nil {
return nil, baderror.WrapQUIC(err)
}
if !response.OK {
return nil, E.New("remote error: ", response.Message)
}
return &clientConn{
Stream: stream,
destination: destination,
}, nil
case N.NetworkUDP:
packetConn, err := c.ListenPacket(ctx, destination)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
for _, payload := range payloads {
_, err = packetConn.WriteTo(payload.Bytes(), destination)
payload.Release()
if err != nil {
return nil, E.Cause(err, "write payload")
}
}
return bufio.NewBindPacketConn(packetConn, destination), nil
default:
return nil, E.Cause(N.ErrUnknownNetwork, network)
}
return &clientConn{
Stream: stream,
destination: destination,
}, nil
}

func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
Expand Down Expand Up @@ -306,48 +365,15 @@ func (c *clientQUICConnection) closeWithError(err error) {

type clientConn struct {
quic.Stream
destination M.Socksaddr
requestWritten bool
responseRead bool
}

func (c *clientConn) NeedHandshake() bool {
return !c.requestWritten
destination M.Socksaddr
}

func (c *clientConn) Read(p []byte) (n int, err error) {
if c.responseRead {
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}
response, err := ReadServerResponse(c.Stream)
if err != nil {
return 0, baderror.WrapQUIC(err)
}
if !response.OK {
err = E.New("remote error: ", response.Message)
return
}
c.responseRead = true
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}

func (c *clientConn) Write(p []byte) (n int, err error) {
if !c.requestWritten {
buffer := WriteClientRequest(ClientRequest{
UDP: false,
Host: c.destination.AddrString(),
Port: c.destination.Port,
}, p)
defer buffer.Release()
_, err = c.Stream.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(p), nil
}
n, err = c.Stream.Write(p)
return n, baderror.WrapQUIC(err)
}
Expand All @@ -357,7 +383,7 @@ func (c *clientConn) LocalAddr() net.Addr {
}

func (c *clientConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
return c.destination
}

func (c *clientConn) Close() error {
Expand Down
9 changes: 6 additions & 3 deletions hysteria/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
return &clientRequest, nil
}

func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer {
func WriteClientRequest(request ClientRequest, payloads []*buf.Buffer) *buf.Buffer {
var requestLen int
requestLen += 1 // udp
requestLen += 2 // host len
requestLen += len(request.Host)
requestLen += 2 // port
buffer := buf.NewSize(requestLen + len(payload))
buffer := buf.NewSize(requestLen + buf.LenMulti(payloads))
if request.UDP {
common.Must(buffer.WriteByte(1))
} else {
Expand All @@ -194,8 +194,11 @@ func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer {
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
common.Error(buffer.WriteString(request.Host)),
binary.Write(buffer, binary.BigEndian, request.Port),
common.Error(buffer.Write(payload)),
)
for _, payload := range payloads {
common.Must1(buffer.Write(payload.Bytes()))
payload.Release()
}
return buffer
}

Expand Down
2 changes: 1 addition & 1 deletion hysteria/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ type serverConn struct {

func (c *serverConn) HandshakeFailure(err error) error {
if c.responseWritten {
return os.ErrClosed
return os.ErrInvalid
}
c.responseWritten = true
return WriteServerResponse(c.Stream, ServerResponse{
Expand Down
Loading

0 comments on commit fb005a6

Please sign in to comment.