diff --git a/go.mod b/go.mod index ac3d228..5bbdbd9 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/gofrs/uuid/v5 v5.3.0 github.com/sagernet/quic-go v0.48.1-beta.1 - github.com/sagernet/sing v0.5.0 + github.com/sagernet/sing v0.6.0-alpha.12 golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 ) diff --git a/go.sum b/go.sum index 5501f1c..8313bda 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5 github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/sagernet/quic-go v0.48.1-beta.1 h1:ElPaV5yzlXIKZpqFMAcUGax6vddi3zt4AEpT94Z0vwo= github.com/sagernet/quic-go v0.48.1-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k= -github.com/sagernet/sing v0.5.0 h1:soo2wVwLcieKWWKIksFNK6CCAojUgAppqQVwyRYGkEM= -github.com/sagernet/sing v0.5.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.12 h1:RqTvSLcgnpcAVz+jzW9UE4IdqUMIxMJwZKRt+d6XDnU= +github.com/sagernet/sing v0.6.0-alpha.12/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/hysteria/client.go b/hysteria/client.go index 2c602cb..e2c091c 100644 --- a/hysteria/client.go +++ b/hysteria/client.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-quic" hyCC "github.com/sagernet/sing-quic/hysteria/congestion" "github.com/sagernet/sing/common/baderror" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" @@ -22,6 +23,29 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) +var ( + _ N.Dialer = (*Client)(nil) + _ N.PayloadDialer = (*Client)(nil) +) + +type Client struct { + ctx context.Context + dialer N.Dialer + logger logger.Logger + brutalDebug bool + serverAddr M.Socksaddr + sendBPS uint64 + receiveBPS uint64 + xplusPassword string + password string + tlsConfig aTLS.Config + quicConfig *quic.Config + udpDisabled bool + + connAccess sync.RWMutex + conn *clientQUICConnection +} + type ClientOptions struct { Context context.Context Dialer N.Dialer @@ -42,24 +66,6 @@ type ClientOptions struct { DisableMTUDiscovery bool } -type Client struct { - ctx context.Context - dialer N.Dialer - logger logger.Logger - brutalDebug bool - serverAddr M.Socksaddr - sendBPS uint64 - receiveBPS uint64 - xplusPassword string - password string - tlsConfig aTLS.Config - quicConfig *quic.Config - udpDisabled bool - - connAccess sync.RWMutex - conn *clientQUICConnection -} - func NewClient(options ClientOptions) (*Client, error) { quicConfig := &quic.Config{ DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), @@ -182,19 +188,72 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { return conn, nil } -func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err +func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + return c.DialPayloadContext(ctx, network, destination, nil) + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - stream, err := conn.quicConn.OpenStream() - if err != nil { - return nil, err +} + +func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + conn, err := c.offer(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + stream, err := conn.quicConn.OpenStreamSync(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + buffer := WriteClientRequest(ClientRequest{ + UDP: false, + Host: destination.AddrString(), + Port: destination.Port, + }, payloads) + _, err = stream.Write(buffer.Bytes()) + buffer.Release() + if err != nil { + return nil, baderror.WrapQUIC(err) + } + response, err := ReadServerResponse(stream) + if err != nil { + return nil, baderror.WrapQUIC(err) + } + if !response.OK { + return nil, E.New("remote error: ", response.Message) + } + return &clientConn{ + Stream: stream, + destination: destination, + }, nil + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + for _, payload := range payloads { + _, err = packetConn.WriteTo(payload.Bytes(), destination) + payload.Release() + if err != nil { + return nil, E.Cause(err, "write payload") + } + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - return &clientConn{ - Stream: stream, - destination: destination, - }, nil } func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { @@ -306,48 +365,15 @@ func (c *clientQUICConnection) closeWithError(err error) { type clientConn struct { quic.Stream - destination M.Socksaddr - requestWritten bool - responseRead bool -} - -func (c *clientConn) NeedHandshake() bool { - return !c.requestWritten + destination M.Socksaddr } func (c *clientConn) Read(p []byte) (n int, err error) { - if c.responseRead { - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) - } - response, err := ReadServerResponse(c.Stream) - if err != nil { - return 0, baderror.WrapQUIC(err) - } - if !response.OK { - err = E.New("remote error: ", response.Message) - return - } - c.responseRead = true n, err = c.Stream.Read(p) return n, baderror.WrapQUIC(err) } func (c *clientConn) Write(p []byte) (n int, err error) { - if !c.requestWritten { - buffer := WriteClientRequest(ClientRequest{ - UDP: false, - Host: c.destination.AddrString(), - Port: c.destination.Port, - }, p) - defer buffer.Release() - _, err = c.Stream.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWritten = true - return len(p), nil - } n, err = c.Stream.Write(p) return n, baderror.WrapQUIC(err) } @@ -357,7 +383,7 @@ func (c *clientConn) LocalAddr() net.Addr { } func (c *clientConn) RemoteAddr() net.Addr { - return M.Socksaddr{} + return c.destination } func (c *clientConn) Close() error { diff --git a/hysteria/protocol.go b/hysteria/protocol.go index 9174cc7..df299b7 100644 --- a/hysteria/protocol.go +++ b/hysteria/protocol.go @@ -178,13 +178,13 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { return &clientRequest, nil } -func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer { +func WriteClientRequest(request ClientRequest, payloads []*buf.Buffer) *buf.Buffer { var requestLen int requestLen += 1 // udp requestLen += 2 // host len requestLen += len(request.Host) requestLen += 2 // port - buffer := buf.NewSize(requestLen + len(payload)) + buffer := buf.NewSize(requestLen + buf.LenMulti(payloads)) if request.UDP { common.Must(buffer.WriteByte(1)) } else { @@ -194,8 +194,11 @@ func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer { binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))), common.Error(buffer.WriteString(request.Host)), binary.Write(buffer, binary.BigEndian, request.Port), - common.Error(buffer.Write(payload)), ) + for _, payload := range payloads { + common.Must1(buffer.Write(payload.Bytes())) + payload.Release() + } return buffer } diff --git a/hysteria/service.go b/hysteria/service.go index 725ffc3..2018a34 100644 --- a/hysteria/service.go +++ b/hysteria/service.go @@ -330,7 +330,7 @@ type serverConn struct { func (c *serverConn) HandshakeFailure(err error) error { if c.responseWritten { - return os.ErrClosed + return os.ErrInvalid } c.responseWritten = true return WriteServerResponse(c.Stream, ServerResponse{ diff --git a/hysteria2/client.go b/hysteria2/client.go index 981f123..91dca94 100644 --- a/hysteria2/client.go +++ b/hysteria2/client.go @@ -21,6 +21,7 @@ import ( hyCC "github.com/sagernet/sing-quic/hysteria/congestion" "github.com/sagernet/sing-quic/hysteria2/internal/protocol" "github.com/sagernet/sing/common/baderror" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -30,19 +31,10 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -type ClientOptions struct { - Context context.Context - Dialer N.Dialer - Logger logger.Logger - BrutalDebug bool - ServerAddress M.Socksaddr - SendBPS uint64 - ReceiveBPS uint64 - SalamanderPassword string - Password string - TLSConfig aTLS.Config - UDPDisabled bool -} +var ( + _ N.Dialer = (*Client)(nil) + _ N.PayloadDialer = (*Client)(nil) +) type Client struct { ctx context.Context @@ -62,6 +54,20 @@ type Client struct { conn *clientQUICConnection } +type ClientOptions struct { + Context context.Context + Dialer N.Dialer + Logger logger.Logger + BrutalDebug bool + ServerAddress M.Socksaddr + SendBPS uint64 + ReceiveBPS uint64 + SalamanderPassword string + Password string + TLSConfig aTLS.Config + UDPDisabled bool +} + func NewClient(options ClientOptions) (*Client, error) { quicConfig := &quic.Config{ DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), @@ -184,22 +190,72 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { return conn, nil } -func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err +func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + return c.DialPayloadContext(ctx, network, destination, nil) + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - stream, err := conn.quicConn.OpenStream() - if err != nil { - return nil, err +} + +func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + conn, err := c.offer(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + stream, err := conn.quicConn.OpenStreamSync(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + buffer := protocol.WriteTCPRequest(destination.String(), payloads) + defer buffer.Release() + _, err = stream.Write(buffer.Bytes()) + if err != nil { + return nil, baderror.WrapQUIC(err) + } + status, errorMessage, err := protocol.ReadTCPResponse(stream) + if err != nil { + return nil, baderror.WrapQUIC(err) + } + if !status { + return nil, E.New("remote error: ", errorMessage) + } + return &clientConn{ + Stream: stream, + destination: destination, + }, nil + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + for _, payload := range payloads { + _, err = packetConn.WriteTo(payload.Bytes(), destination) + payload.Release() + if err != nil { + buf.ReleaseMulti(payloads) + return nil, E.Cause(err, "write payload") + } + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - return &clientConn{ - Stream: stream, - destination: destination, - }, nil } -func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { +func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { if c.udpDisabled { return nil, os.ErrInvalid } @@ -270,44 +326,15 @@ func (c *clientQUICConnection) closeWithError(err error) { type clientConn struct { quic.Stream - destination M.Socksaddr - requestWritten bool - responseRead bool -} - -func (c *clientConn) NeedHandshake() bool { - return !c.requestWritten + destination M.Socksaddr } func (c *clientConn) Read(p []byte) (n int, err error) { - if c.responseRead { - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) - } - status, errorMessage, err := protocol.ReadTCPResponse(c.Stream) - if err != nil { - return 0, baderror.WrapQUIC(err) - } - if !status { - err = E.New("remote error: ", errorMessage) - return - } - c.responseRead = true n, err = c.Stream.Read(p) return n, baderror.WrapQUIC(err) } func (c *clientConn) Write(p []byte) (n int, err error) { - if !c.requestWritten { - buffer := protocol.WriteTCPRequest(c.destination.String(), p) - defer buffer.Release() - _, err = c.Stream.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWritten = true - return len(p), nil - } n, err = c.Stream.Write(p) return n, baderror.WrapQUIC(err) } @@ -317,10 +344,14 @@ func (c *clientConn) LocalAddr() net.Addr { } func (c *clientConn) RemoteAddr() net.Addr { - return M.Socksaddr{} + return c.destination } func (c *clientConn) Close() error { c.Stream.CancelRead(0) return c.Stream.Close() } + +func (c *clientConn) Upstream() any { + return c.Stream +} diff --git a/hysteria2/internal/protocol/proxy.go b/hysteria2/internal/protocol/proxy.go index 3b0c2f1..1219fa6 100644 --- a/hysteria2/internal/protocol/proxy.go +++ b/hysteria2/internal/protocol/proxy.go @@ -66,21 +66,24 @@ func ReadTCPRequest(r io.Reader) (string, error) { return string(addrBuf), nil } -func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { +func WriteTCPRequest(addr string, payloads []*buf.Buffer) *buf.Buffer { padding := tcpRequestPadding.String() paddingLen := len(padding) addrLen := len(addr) sz := int(quicvarint.Len(FrameTypeTCPRequest)) + int(quicvarint.Len(uint64(addrLen))) + addrLen + int(quicvarint.Len(uint64(paddingLen))) + paddingLen - buffer := buf.NewSize(sz + len(payload)) + buffer := buf.NewSize(sz + buf.LenMulti(payloads)) bufferContent := buffer.Extend(sz) i := varintPut(bufferContent, FrameTypeTCPRequest) i += varintPut(bufferContent[i:], uint64(addrLen)) i += copy(bufferContent[i:], addr) i += varintPut(bufferContent[i:], uint64(paddingLen)) copy(bufferContent[i:], padding) - buffer.Write(payload) + for _, payload := range payloads { + buffer.Write(payload.Bytes()) + payload.Release() + } return buffer } diff --git a/hysteria2/service.go b/hysteria2/service.go index d9c6eb6..3095cbe 100644 --- a/hysteria2/service.go +++ b/hysteria2/service.go @@ -299,7 +299,7 @@ type serverConn struct { func (c *serverConn) HandshakeFailure(err error) error { if c.responseWritten { - return os.ErrClosed + return os.ErrInvalid } c.responseWritten = true buffer := protocol.WriteTCPResponse(false, err.Error(), nil) diff --git a/tuic/client.go b/tuic/client.go index 1ad12d3..a287945 100644 --- a/tuic/client.go +++ b/tuic/client.go @@ -20,18 +20,10 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -type ClientOptions struct { - Context context.Context - Dialer N.Dialer - ServerAddress M.Socksaddr - TLSConfig aTLS.Config - UUID [16]byte - Password string - CongestionControl string - UDPStream bool - ZeroRTTHandshake bool - Heartbeat time.Duration -} +var ( + _ N.Dialer = (*Client)(nil) + _ N.PayloadDialer = (*Client)(nil) +) type Client struct { ctx context.Context @@ -50,6 +42,19 @@ type Client struct { conn *clientQUICConnection } +type ClientOptions struct { + Context context.Context + Dialer N.Dialer + ServerAddress M.Socksaddr + TLSConfig aTLS.Config + UUID [16]byte + Password string + CongestionControl string + UDPStream bool + ZeroRTTHandshake bool + Heartbeat time.Duration +} + func NewClient(options ClientOptions) (*Client, error) { if options.Heartbeat == 0 { options.Heartbeat = 10 * time.Second @@ -171,23 +176,73 @@ func (c *Client) loopHeartbeats(conn *clientQUICConnection) { } } -func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err +func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + return c.DialPayloadContext(ctx, network, destination, nil) + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - stream, err := conn.quicConn.OpenStream() - if err != nil { - return nil, err +} + +func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + conn, err := c.offer(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + stream, err := conn.quicConn.OpenStreamSync(ctx) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + request := buf.NewSize(2 + AddressSerializer.AddrPortLen(destination) + buf.LenMulti(payloads)) + defer request.Release() + request.WriteByte(Version) + request.WriteByte(CommandConnect) + common.Must(AddressSerializer.WriteAddrPort(request, destination)) + for _, payload := range payloads { + common.Must1(request.Write(payload.Bytes())) + payload.Release() + } + _, err = stream.Write(request.Bytes()) + if err != nil { + conn.closeWithError(err) + return nil, E.Cause(baderror.WrapQUIC(err), "write request") + } + return &clientConn{ + Stream: stream, + parent: conn, + destination: destination, + }, nil + case N.NetworkUDP: + packetConn, err := c.ListenPacket(ctx, destination) + if err != nil { + buf.ReleaseMulti(payloads) + return nil, err + } + for _, payload := range payloads { + _, err = packetConn.WriteTo(payload.Bytes(), destination) + payload.Release() + if err != nil { + return nil, E.Cause(err, "write payload") + } + } + return bufio.NewBindPacketConn(packetConn, destination), nil + default: + return nil, E.Cause(N.ErrUnknownNetwork, network) } - return &clientConn{ - Stream: stream, - parent: conn, - destination: destination, - }, nil } -func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { +func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { conn, err := c.offer(ctx) if err != nil { return nil, err @@ -251,13 +306,8 @@ func (c *clientQUICConnection) closeWithError(err error) { type clientConn struct { quic.Stream - parent *clientQUICConnection - destination M.Socksaddr - requestWritten bool -} - -func (c *clientConn) NeedHandshake() bool { - return !c.requestWritten + parent *clientQUICConnection + destination M.Socksaddr } func (c *clientConn) Read(b []byte) (n int, err error) { @@ -266,24 +316,6 @@ func (c *clientConn) Read(b []byte) (n int, err error) { } func (c *clientConn) Write(b []byte) (n int, err error) { - if !c.requestWritten { - request := buf.NewSize(2 + AddressSerializer.AddrPortLen(c.destination) + len(b)) - defer request.Release() - request.WriteByte(Version) - request.WriteByte(CommandConnect) - err = AddressSerializer.WriteAddrPort(request, c.destination) - if err != nil { - return - } - request.Write(b) - _, err = c.Stream.Write(request.Bytes()) - if err != nil { - c.parent.closeWithError(E.Cause(err, "create new connection")) - return 0, baderror.WrapQUIC(err) - } - c.requestWritten = true - return len(b), nil - } n, err = c.Stream.Write(b) return n, baderror.WrapQUIC(err) }