-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.go
203 lines (170 loc) · 4.44 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package tcprouter
import (
"bufio"
"context"
"fmt"
"io"
"net"
"github.com/libp2p/go-yamux"
"github.com/rs/zerolog/log"
)
// Client connects to a TCP router server and opens a reverse tunnel through
// which streams accepted from the server are forwarded to a local application.
type Client struct {
// localAddr is the address dialed for incoming streams not detected as TLS
localAddr string
// localTLSAddr is the address dialed for incoming streams detected as TLS
localTLSAddr string
// remoteAddr is the address of the TCP router server
remoteAddr string
// secret used to identify the connection in the TCP router server
secret []byte
// connection (yamux session) to the TCP router server
remoteSession *yamux.Session
}
// NewClient builds a TCP router client.
//
// secret identifies this client to the router server, local and localTLS are
// the addresses of the local application for plain and TLS traffic, and remote
// is the address of the router server. No connection is opened here.
func NewClient(secret, local, localTLS, remote string) *Client {
	client := new(Client)
	client.secret = []byte(secret)
	client.remoteAddr = remote
	client.localAddr = local
	client.localTLSAddr = localTLS
	return client
}
// Start starts the client by opening a connection to the router server, doing
// the handshake, and then listening for incoming streams from the router server.
//
// It blocks until listening stops. It returns a wrapped error when connecting,
// the handshake, or listening fails; otherwise it returns whatever listen
// returns. The session opened by this call is closed before Start returns.
func (c Client) Start(ctx context.Context) error {
	if err := c.connectRemote(c.remoteAddr); err != nil {
		return fmt.Errorf("failed to connect to TCP router server: %w", err)
	}
	// The receiver is a copy, so the session stored by connectRemote is only
	// reachable from this call. Close it on every exit path, otherwise the
	// session and its TCP connection leak.
	defer func() {
		if err := c.remoteSession.Close(); err != nil {
			log.Error().Err(err).Msg("Error while closing session with TCP router server")
		}
	}()

	log.Info().Msg("start handshake")
	if err := c.handshake(); err != nil {
		return fmt.Errorf("failed to handshake with TCP router server: %w", err)
	}
	log.Info().Msg("handshake done")

	return c.listen(ctx)
}
// connectRemote dials the router server at addr over TCP and sets up the
// client side of a yamux session on that connection, storing the session in
// c.remoteSession.
//
// It returns an error when no secret is configured, when addr cannot be
// resolved or dialed, or when the yamux session cannot be created.
func (c *Client) connectRemote(addr string) error {
	if len(c.secret) == 0 {
		return fmt.Errorf("no secret configured")
	}

	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return err
	}

	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		return err
	}

	// Setup client side of yamux
	session, err := yamux.Client(conn, nil)
	if err != nil {
		// Report the failure to the caller instead of panicking, and do not
		// leak the freshly dialed connection. The close error is ignored on
		// purpose: the setup error is the one worth reporting.
		_ = conn.Close()
		return fmt.Errorf("setting up yamux session: %w", err)
	}

	c.remoteSession = session
	return nil
}
// connectLocal resolves addr and opens a TCP connection to it.
// The resolve or dial error is returned unchanged.
func (c *Client) connectLocal(addr string) (WriteCloser, error) {
	resolved, resolveErr := net.ResolveTCPAddr("tcp", addr)
	if resolveErr != nil {
		return nil, resolveErr
	}

	local, dialErr := net.DialTCP("tcp", nil, resolved)
	if dialErr != nil {
		return nil, dialErr
	}

	return local, nil
}
// handshake opens a stream on the remote session and writes a Handshake
// carrying the magic number and the client secret to it. The stream is closed
// before returning.
//
// It fails with "not connected" when no session has been established, and
// otherwise returns the error from opening the stream or writing the handshake.
func (c *Client) handshake() error {
	session := c.remoteSession
	if session == nil {
		return fmt.Errorf("not connected")
	}

	// A server that refuses the handshake is expected to just close the
	// connection, which should surface as an error.
	stream, err := session.OpenStream()
	if err != nil {
		return err
	}
	defer stream.Close()

	hello := Handshake{MagicNr: MagicNr, Secret: c.secret}
	return hello.Write(stream)
}
// listen accepts streams opened by the router server on the remote session and
// forwards each of them to the local application.
//
// For every accepted stream the first bytes are inspected through
// clientHelloServerName: when it reports TLS the stream is connected to
// localTLSAddr, otherwise to localAddr. Data is then copied in both directions
// in a separate goroutine, and both ends are closed once both directions are done.
//
// It returns nil when ctx is done, and an error when accepting a stream or
// connecting to the local application fails.
func (c *Client) listen(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	cCon := make(chan WriteCloser)
	cErr := make(chan error)

	// Accept loop. Every send also watches ctx so this goroutine cannot stay
	// blocked on an unbuffered channel once listen has returned.
	go func() {
		for {
			if ctx.Err() != nil {
				return
			}

			conn, err := c.remoteSession.AcceptStream()
			if err != nil {
				select {
				case cErr <- err:
				case <-ctx.Done():
				}
				return
			}

			select {
			case cCon <- WrapConn(conn):
			case <-ctx.Done():
				// Nobody is left to serve this stream. Best-effort close;
				// there is no caller to report the error to.
				_ = conn.Close()
				return
			}
		}
	}()

	for {
		select {
		case <-ctx.Done():
			return nil
		case err := <-cErr:
			return fmt.Errorf("accept connection failed: %w", err)
		case remote := <-cCon:
			log.Info().
				Str("remote add", remote.RemoteAddr().String()).
				Msg("incoming stream, connect to local application")

			// Peek at the beginning of the stream to pick the local target.
			br := bufio.NewReader(remote)
			_, isTLS, peeked := clientHelloServerName(br)

			target := c.localAddr
			if isTLS {
				target = c.localTLSAddr
			}

			local, err := c.connectLocal(target)
			if err != nil {
				// Do not leave the accepted stream open when bailing out.
				if cerr := remote.Close(); cerr != nil {
					log.Error().Err(cerr).Msg("Error while terminating connection")
				}
				return fmt.Errorf("failed to connect to local application: %w", err)
			}

			// GetConn is handed the peeked data along with the stream.
			incoming := GetConn(remote, peeked)

			go func(remote, local WriteCloser) {
				log.Info().Msg("start forwarding")

				done := make(chan error)
				go forward(local, remote, done)
				go forward(remote, local, done)

				// Wait for both directions and report each failure.
				for i := 0; i < 2; i++ {
					if err := <-done; err != nil {
						log.Error().Err(err).Msg("Error during forwarding")
					}
				}

				if err := remote.Close(); err != nil {
					log.Error().Err(err).Msg("Error while terminating connection")
				}
				if err := local.Close(); err != nil {
					log.Error().Err(err).Msg("Error while terminating connection")
				}
			}(incoming, local)
		}
	}
}
// forward copies everything from src to dst, reports the copy result (nil or
// error) on cErr, and then closes the write side of dst. A failure to close is
// only logged.
func forward(dst, src WriteCloser, cErr chan<- error) {
	_, copyErr := io.Copy(dst, src)
	cErr <- copyErr

	closeErr := dst.CloseWrite()
	if closeErr == nil {
		return
	}
	log.Error().Err(closeErr).Msgf("error closing %s", dst.RemoteAddr().String())
}
// wrappedCon embeds a yamux stream so that a CloseWrite method can be
// attached to it.
type wrappedCon struct {
*yamux.Stream
}
// WrapConn wraps a yamux stream into a wrappedCon so that it implements the
// WriteCloser interface.
func WrapConn(conn *yamux.Stream) WriteCloser {
	return wrappedCon{Stream: conn}
}
// CloseWrite closes the wrapped stream by calling its Close method.
func (c wrappedCon) CloseWrite() error {
	stream := c.Stream
	return stream.Close()
}