diff --git a/ctx.go b/ctx.go index 330a45e6..4518bcfa 100644 --- a/ctx.go +++ b/ctx.go @@ -9,20 +9,15 @@ import ( // ProxyCtx is the Proxy context, contains useful information about every request. It is passed to // every user function. Also used as a logger. type ProxyCtx struct { - // Will contain the client request from the proxy - Req *http.Request - // Will contain the remote server's response (if available. nil if the request wasn't send yet) - Resp *http.Response + Req *http.Request // Client request to the proxy + Resp *http.Response // Remote server's response (nil if the request wasn't send yet) + Websocket bool // true if Connection is a Websocket RoundTripper RoundTripper - // will contain the recent error that occurred while trying to send receive or parse traffic - Error error - // A handle for the user to keep data in the context, from the call of ReqHandler to the - // call of RespHandler - UserData interface{} - // Will connect a request to a response - Session int64 - signer func(ca *tls.Certificate, hostname []string) (*tls.Certificate, error) - proxy *ProxyHttpServer + Error error // The recent error that occurred while trying to send receive or parse traffic + UserData interface{} // User data kept in the context, from the call of ReqHandler to the call of RespHandler + Session int64 // Invariant from a request to a response + signer func(ca *tls.Certificate, hostname []string) (*tls.Certificate, error) + proxy *ProxyHttpServer } type RoundTripper interface { diff --git a/https.go b/https.go index 123d6cbb..eef3c81c 100644 --- a/https.go +++ b/https.go @@ -141,58 +141,37 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque todo.Hijack(r, proxyClient, ctx) case ConnectHTTPMitm: - targetSite, err := proxy.connectDial("tcp", host) - if err != nil { - ctx.Warnf("Error dialing to %s: %s", host, err.Error()) - return - } - proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") client := bufio.NewReader(proxyClient) - remote := bufio.NewReader(targetSite) var ( - req *http.Request - resp *http.Response + req *http.Request + err error ) for { - // 1. read the request from the client req, err = http.ReadRequest(client) if err != nil { if err != io.EOF { - ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) + ctx.Warnf("Cannot read request of MITM HTTP client: %+#v", err) } - return + break } - // 2. filter the client request - req, resp = proxy.filterRequest(req, ctx) - if resp == nil { - if err = req.Write(targetSite); err != nil { - httpError(proxyClient, ctx, err) - return - } - resp, err = http.ReadResponse(remote, req) + req.URL, err = url.Parse("http://" + req.Host + req.URL.String()) + + if end, err := proxy.handleRequest(NewConnResponseWriter(proxyClient), req); end { if err != nil { - httpError(proxyClient, ctx, err) - return + ctx.Warnf("Error during serving MITM HTTP request: %+#v", err) } - } - - // 3. filter the response - resp = proxy.filterResponse(resp, ctx) - err = resp.Write(proxyClient) - resp.Body.Close() - - if err != nil { - httpError(proxyClient, ctx, err) - return + break } } + proxyClient.Close() + case ConnectMitm: // This goes in a separate goroutine, so that the net/http server won't think we're // still handling the request even after hijacking the connection. Those HTTP CONNECT @@ -218,88 +197,36 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque return } defer rawClientTls.Close() + defer proxyClient.Close() clientTls := bufio.NewReader(rawClientTls) for { - // 1. Read a request from the client. + // Read a request from the client. req, err := http.ReadRequest(clientTls) if err != nil { if err != io.EOF { ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) - return } // EOF break - } else if req == nil { - ctx.Warnf("Empty request from mitm'd client") - return } - // 2. Setup a new ProxyCtx for the intercepted - // stream. - nctx := &ProxyCtx{ - Req: req, - Session: atomic.AddInt64(&proxy.sess, 1), - proxy: proxy, - } - - // Since we're converting the request, we need - // to carry over the original connecting IP as - // well. - req.RemoteAddr = r.RemoteAddr - nctx.Logf("req %v", r.Host) - if !httpsRegexp.MatchString(req.URL.String()) { req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) // err is handled below } - // Put the original request from the client - // into the context so is available at later - // time. - nctx.Req = req - - // 3. Filter the request. - filreq, resp := proxy.filterRequest(req, nctx) - if resp == nil { - // err is from the call to url.Parse above - if err != nil { - nctx.Warnf("Illegal URL %s", "https://"+r.Host+filreq.URL.Path) - return - } else if filreq == nil { - nctx.Warnf("Empty filtered request") - return - } - - removeProxyHeaders(nctx, filreq) - - // Send the request to the target - resp, err = nctx.RoundTrip(filreq) + if end, err := proxy.handleRequest(NewConnResponseWriter(rawClientTls), req); end { if err != nil { - nctx.Warnf("Cannot read TLS response from mitm'd server %v", err) - return + ctx.Warnf("Error during serving MITM HTTPS request: %+#v", err) } - nctx.Logf("resp %v", resp.Status) - } - - // 4. Filter the response. - filtered := proxy.filterResponse(resp, nctx) - - // 5. Write the filtered response to the client - err = filtered.Write(rawClientTls) - resp.Body.Close() - filtered.Body.Close() - if err != nil { - nctx.Warnf("Failed to write response to client: %v", err) - return - } - - if nctx.Req.Close { - nctx.Warnf("Non-persistent connection; closing") + break } } + ctx.Logf("Exiting on EOF") + proxyClient.Close() }() case ConnectProxyAuthHijack: diff --git a/proxy.go b/proxy.go index 50baa5ab..ef1cb89b 100644 --- a/proxy.go +++ b/proxy.go @@ -1,15 +1,19 @@ package goproxy import ( - "bufio" + "bytes" "crypto/tls" "io" + "io/ioutil" "log" "net" "net/http" "os" "regexp" + "sync" "sync/atomic" + + "github.com/gorilla/websocket" ) // The basic proxy type. Implements http.Handler. @@ -32,7 +36,9 @@ type ProxyHttpServer struct { ConnectDial func(network string, addr string) (net.Conn, error) // Signer can be set by consumers with their own implementation. This allows // f.e. for caching of Certificates. - Signer func(ca *tls.Certificate, hostname []string) (*tls.Certificate, error) + Signer func(ca *tls.Certificate, hostname []string) (*tls.Certificate, error) + WsServer *websocket.Upgrader + WsDialer *websocket.Dialer } var hasPort = regexp.MustCompile(`:\d+$`) @@ -50,14 +56,6 @@ func copyHeaders(dst, src http.Header, keepDestHeaders bool) { } } -func isEof(r *bufio.Reader) bool { - _, err := r.Peek(1) - if err == io.EOF { - return true - } - return false -} - func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req *http.Request, resp *http.Response) { req = r for _, h := range proxy.reqHandlers { @@ -79,83 +77,221 @@ func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *Proxy return } -func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { - r.RequestURI = "" // this must be reset when serving a request with the client - ctx.Logf("Sending request %v %v", r.Method, r.URL.String()) - if !ctx.proxy.KeepAcceptEncoding { - // If no Accept-Encoding header exists, Transport will add the headers it can accept - // and would wrap the response body with the relevant reader. - r.Header.Del("Accept-Encoding") - } - // curl can add that, see - // https://jdebp.eu./FGA/web-proxy-connection-header.html - r.Header.Del("Proxy-Connection") - r.Header.Del("Proxy-Authenticate") - r.Header.Del("Proxy-Authorization") - // Connection, Authenticate and Authorization are single hop Header: - // http://www.w3.org/Protocols/rfc2616/rfc2616.txt - // 14.10 Connection - // The Connection general-header field allows the sender to specify - // options that are desired for that particular connection and MUST NOT - // be communicated by proxies over further connections. - r.Header.Del("Connection") - r.Close = false +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) { + ctx.Logf("Copying response to client: %v (%d bytes)", resp.Status, resp.ContentLength) + + // Fancy ResponseWriter + if w, ok := out.(*connResponseWriter); ok { + // net/http: Response.Write produces invalid responses in this case, + // hacking to fix that + if resp.ContentLength == -1 { + defer resp.Body.Close() + + peek, err := ioutil.ReadAll( + io.LimitReader(resp.Body, 4*1024), + ) + + body := bytes.NewReader(peek) + + if err != nil { + ctx.Warnf("Error copying response: %s", err.Error()) + } + + if len(peek) < 4*1024 { + resp.ContentLength = int64(body.Len()) + resp.Body = ioutil.NopCloser(body) + } else { + resp.TransferEncoding = append(resp.TransferEncoding, "chunked") + resp.Body = ioutil.NopCloser(io.MultiReader( + body, + resp.Body, + )) + } + } + + if err := resp.Write(w); err != nil { + ctx.Warnf("Error copying response: %s", err.Error()) + } else { + ctx.Logf("Copied response to client") + } + + return + } + + // Standard ResponseWriter + // 1 + for k, _ := range out.Header() { + out.Header().Del(k) + } + + for k, vs := range resp.Header { + for _, v := range vs { + out.Header().Add(k, v) + } + } + + // 2 + out.WriteHeader(resp.StatusCode) + + // 3 + if nr, err := io.Copy(out, resp.Body); err != nil { + ctx.Logf("Copied %v bytes to client with error: %v", nr, err) + } else { + ctx.Logf("Copied %v bytes to client", nr) + } + + // 4 + if err := resp.Body.Close(); err != nil { + ctx.Warnf("Can't close response body: %v", err) + } } // Standard net/http function. Shouldn't be used directly, http.Serve will use it. func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { + // Can be SSL and WebSockets proxy.handleConnect(w, r) } else { - ctx := &ProxyCtx{ - Req: r, - Session: atomic.AddInt64(&proxy.sess, 1), - proxy: proxy, - } + // Common HTTP proxy + proxy.handleRequest(w, r) + } +} + +func (proxy *ProxyHttpServer) handleRequest(writer http.ResponseWriter, base *http.Request) (bool, error) { + ctx := &ProxyCtx{ + Req: base, + Session: atomic.AddInt64(&proxy.sess, 1), + Websocket: websocket.IsWebSocketUpgrade(base), + proxy: proxy, + } + // Clean-up + base.RequestURI = "" - var err error - ctx.Logf("Got request %v %v %v %v", r.URL.Path, r.Host, r.Method, r.URL.String()) - if !r.URL.IsAbs() { - proxy.NonproxyHandler.ServeHTTP(w, r) - return + if websocket.IsWebSocketUpgrade(base) { + return proxy.handleWsRequest(ctx, writer, base) + } else { + return proxy.handleHttpRequest(ctx, writer, base) + } +} + +func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, writer http.ResponseWriter, base *http.Request) (bool, error) { + var ( + req *http.Request + resp *http.Response + err error + ) + + ctx.Logf("Relying http(s) request to: %v", base.URL.String()) + + req, resp = proxy.filterRequest(base, ctx) + if resp == nil { + // Clean-up request + for _, h := range hopHeaders { + req.Header.Del(h) } - r, resp := proxy.filterRequest(r, ctx) - if resp == nil { - removeProxyHeaders(ctx, r) - resp, err = ctx.RoundTrip(r) - if err != nil { - ctx.Error = err - resp = proxy.filterResponse(nil, ctx) - if resp == nil { - ctx.Logf("error read response %v %v:", r.URL.Host, err.Error()) - http.Error(w, err.Error(), 500) - return - } - } - ctx.Logf("Received response %v", resp.Status) + // Process + resp, err = ctx.RoundTrip(req) + + // Clean-up response + for _, h := range hopHeaders { + resp.Header.Del(h) } - origBody := resp.Body - resp = proxy.filterResponse(resp, ctx) - defer origBody.Close() - ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) - // http.ResponseWriter will take care of filling the correct response length - // Setting it now, might impose wrong value, contradicting the actual new - // body the user returned. - // We keep the original body to remove the header only if things changed. - // This will prevent problems with HEAD requests where there's no body, yet, - // the Content-Length header should be set. - if origBody != resp.Body { - resp.Header.Del("Content-Length") + } + + if err != nil { + ctx.Logf("Error reading response %v: %v", req.URL.Host, err.Error()) + + // TODO: add gateway timeout error in case of timeout + switch err { + default: + http.Error(writer, err.Error(), http.StatusBadGateway) } - copyHeaders(w.Header(), resp.Header, proxy.KeepDestinationHeaders) - w.WriteHeader(resp.StatusCode) - nr, err := io.Copy(w, resp.Body) - if err := resp.Body.Close(); err != nil { - ctx.Warnf("Can't close response body %v", err) + + return false, err + } + + body := resp.Body + defer body.Close() + + resp = proxy.filterResponse(resp, ctx) + + // http.ResponseWriter will take care of filling the correct response length + // Setting it now, might impose wrong value, contradicting the actual new + // body the user returned. + // We keep the original body to remove the header only if things changed. + // This will prevent problems with HEAD requests where there's no body, yet, + // the Content-Length header should be set. + if body != resp.Body { + resp.Header.Del("Content-Length") + } + + ctx.Logf("Received response: %v", resp.Status) + + writeResponse(ctx, resp, writer) + + return false, err +} + +// TODO: add handshake filter and message introspection +func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, writer http.ResponseWriter, base *http.Request) (bool, error) { + proto := websocket.Subprotocols(base) + wg := &sync.WaitGroup{} + + switch base.URL.Scheme { + case "http": + base.URL.Scheme = "ws" + + case "https": + base.URL.Scheme = "wss" + } + + ctx.Logf("Relying websocket connection %s with protocols: %v", base.URL.String(), proto) + + remote, resp, err := proxy.WsDialer.Dial( + base.URL.String(), + nil, + ) + + if err != nil { + if err == websocket.ErrBadHandshake { + writeResponse(ctx, resp, writer) + } else { + Error(writer, err.Error(), http.StatusBadGateway) } - ctx.Logf("Copied %v bytes to client error=%v", nr, err) + return true, err } + + client, err := proxy.WsServer.Upgrade(writer, base, nil) + if err != nil { + return true, err + } + + wg.Add(2) + + go wsRelay(ctx, remote, client, wg) + go wsRelay(ctx, client, remote, wg) + + wg.Wait() + + remote.Close() + client.Close() + + return true, nil } // NewProxyHttpServer creates and returns a proxy server, logging to stderr by default