From 638176c8c30c137c9e87237ea76f530f2b3591cd Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Wed, 14 Sep 2016 09:32:31 +0300 Subject: [PATCH 1/8] Added gorilla/websocket dependency --- vendor/github.com/gorilla/websocket/AUTHORS | 8 + vendor/github.com/gorilla/websocket/LICENSE | 22 + vendor/github.com/gorilla/websocket/README.md | 61 ++ .../gorilla/websocket/bench_test.go | 19 + vendor/github.com/gorilla/websocket/client.go | 375 +++++++ .../gorilla/websocket/client_server_test.go | 448 ++++++++ .../gorilla/websocket/client_test.go | 72 ++ .../gorilla/websocket/compression.go | 85 ++ .../gorilla/websocket/compression_test.go | 31 + vendor/github.com/gorilla/websocket/conn.go | 994 ++++++++++++++++++ .../github.com/gorilla/websocket/conn_read.go | 18 + .../gorilla/websocket/conn_read_legacy.go | 21 + .../github.com/gorilla/websocket/conn_test.go | 402 +++++++ vendor/github.com/gorilla/websocket/doc.go | 152 +++ .../gorilla/websocket/example_test.go | 46 + vendor/github.com/gorilla/websocket/json.go | 55 + .../github.com/gorilla/websocket/json_test.go | 119 +++ vendor/github.com/gorilla/websocket/server.go | 261 +++++ .../gorilla/websocket/server_test.go | 51 + vendor/github.com/gorilla/websocket/util.go | 214 ++++ .../github.com/gorilla/websocket/util_test.go | 74 ++ vendor/vendor.json | 13 + 22 files changed, 3541 insertions(+) create mode 100644 vendor/github.com/gorilla/websocket/AUTHORS create mode 100644 vendor/github.com/gorilla/websocket/LICENSE create mode 100644 vendor/github.com/gorilla/websocket/README.md create mode 100644 vendor/github.com/gorilla/websocket/bench_test.go create mode 100644 vendor/github.com/gorilla/websocket/client.go create mode 100644 vendor/github.com/gorilla/websocket/client_server_test.go create mode 100644 vendor/github.com/gorilla/websocket/client_test.go create mode 100644 vendor/github.com/gorilla/websocket/compression.go create mode 100644 vendor/github.com/gorilla/websocket/compression_test.go create mode 100644 vendor/github.com/gorilla/websocket/conn.go create mode 100644 vendor/github.com/gorilla/websocket/conn_read.go create mode 100644 vendor/github.com/gorilla/websocket/conn_read_legacy.go create mode 100644 vendor/github.com/gorilla/websocket/conn_test.go create mode 100644 vendor/github.com/gorilla/websocket/doc.go create mode 100644 vendor/github.com/gorilla/websocket/example_test.go create mode 100644 vendor/github.com/gorilla/websocket/json.go create mode 100644 vendor/github.com/gorilla/websocket/json_test.go create mode 100644 vendor/github.com/gorilla/websocket/server.go create mode 100644 vendor/github.com/gorilla/websocket/server_test.go create mode 100644 vendor/github.com/gorilla/websocket/util.go create mode 100644 vendor/github.com/gorilla/websocket/util_test.go create mode 100644 vendor/vendor.json diff --git a/vendor/github.com/gorilla/websocket/AUTHORS b/vendor/github.com/gorilla/websocket/AUTHORS new file mode 100644 index 00000000..b003eca0 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/AUTHORS @@ -0,0 +1,8 @@ +# This is the official list of Gorilla WebSocket authors for copyright +# purposes. +# +# Please keep the list sorted. + +Gary Burd +Joachim Bauch + diff --git a/vendor/github.com/gorilla/websocket/LICENSE b/vendor/github.com/gorilla/websocket/LICENSE new file mode 100644 index 00000000..9171c972 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/gorilla/websocket/README.md b/vendor/github.com/gorilla/websocket/README.md new file mode 100644 index 00000000..9d71959e --- /dev/null +++ b/vendor/github.com/gorilla/websocket/README.md @@ -0,0 +1,61 @@ +# Gorilla WebSocket + +Gorilla WebSocket is a [Go](http://golang.org/) implementation of the +[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +### Documentation + +* [API Reference](http://godoc.org/github.com/gorilla/websocket) +* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) + +### Status + +The Gorilla WebSocket package provides a complete and tested implementation of +the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The +package API is stable. + +### Installation + + go get github.com/gorilla/websocket + +### Protocol Compliance + +The Gorilla WebSocket package passes the server tests in the [Autobahn Test +Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). + +### Gorilla WebSocket compared with other packages + + + + + + + + + + + + + + + + + + +
github.com/gorillagolang.org/x/net
RFC 6455 Features
Passes Autobahn Test SuiteYesNo
Receive fragmented messageYesNo, see note 1
Send close messageYesNo
Send pings and receive pongsYesNo
Get the type of a received data messageYesYes, see note 2
Other Features
Limit size of received messageYesNo
Read message using io.ReaderYesNo, see note 3
Write message using io.WriteCloserYesNo, see note 3
+ +Notes: + +1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). +2. The application can get the type of a received data message by implementing + a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) + function. +3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. + Read returns when the input buffer is full or a frame boundary is + encountered. Each call to Write sends a single frame message. The Gorilla + io.Reader and io.WriteCloser operate on a single WebSocket message. + diff --git a/vendor/github.com/gorilla/websocket/bench_test.go b/vendor/github.com/gorilla/websocket/bench_test.go new file mode 100644 index 00000000..f66fc36b --- /dev/null +++ b/vendor/github.com/gorilla/websocket/bench_test.go @@ -0,0 +1,19 @@ +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "testing" +) + +func BenchmarkMaskBytes(b *testing.B) { + var key [4]byte + data := make([]byte, 1024) + pos := 0 + for i := 0; i < b.N; i++ { + pos = maskBytes(key, pos, data) + } + b.SetBytes(int64(len(data))) +} diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go new file mode 100644 index 00000000..879d33ed --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client.go @@ -0,0 +1,375 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +// +// Deprecated: Use Dialer instead. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, + } + return d.Dial(u.String(), requestHeader) +} + +// A Dialer contains options for connecting to WebSocket server. +type Dialer struct { + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dial is used. + NetDial func(network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with tls.Client. + // If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // Input and output buffer sizes. If the buffer size is zero, then a + // default value of 4096 is used. + ReadBufferSize, WriteBufferSize int + + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string +} + +var errMalformedURL = errors.New("malformed ws or wss URL") + +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. +func parseURL(s string) (*url.URL, error) { + // From the RFC: + // + // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] + // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] + + var u url.URL + switch { + case strings.HasPrefix(s, "ws://"): + u.Scheme = "ws" + s = s[len("ws://"):] + case strings.HasPrefix(s, "wss://"): + u.Scheme = "wss" + s = s[len("wss://"):] + default: + return nil, errMalformedURL + } + + if i := strings.Index(s, "?"); i >= 0 { + u.RawQuery = s[i+1:] + s = s[:i] + } + + if i := strings.Index(s, "/"); i >= 0 { + u.Opaque = s[i:] + s = s[:i] + } else { + u.Opaque = "/" + } + + u.Host = s + + if strings.Contains(u.Host, "@") { + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. + return nil, errMalformedURL + } + + return &u, nil +} + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +// DefaultDialer is a dialer with all fields set to the default zero values. +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, +} + +// Dial creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + + if d == nil { + d = &Dialer{ + Proxy: http.ProxyFromEnvironment, + } + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + + u, err := parseURL(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + default: + req.Header[k] = vs + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + + var proxyURL *url.URL + // Check wether the proxy method has been configured + if d.Proxy != nil { + proxyURL, err = d.Proxy(req) + } + if err != nil { + return nil, nil, err + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort + } + + var deadline time.Time + if d.HandshakeTimeout != 0 { + deadline = time.Now().Add(d.HandshakeTimeout) + } + + netDial := d.NetDial + if netDial == nil { + netDialer := &net.Dialer{Deadline: deadline} + netDial = netDialer.Dial + } + + netConn, err := netDial("tcp", targetHostPort) + if err != nil { + return nil, nil, err + } + + defer func() { + if netConn != nil { + netConn.Close() + } + }() + + if err := netConn.SetDeadline(deadline); err != nil { + return nil, nil, err + } + + if proxyURL != nil { + connectHeader := make(http.Header) + if user := proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: connectHeader, + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + + if u.Scheme == "https" { + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + if err := tlsConn.Handshake(); err != nil { + return nil, nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return nil, nil, err + } + } + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(conn.br, req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + netConn.SetDeadline(time.Time{}) + netConn = nil // to avoid close in defer. + return conn, resp, nil +} + +// cloneTLSConfig clones all public fields except the fields +// SessionTicketsDisabled and SessionTicketKey. This avoids copying the +// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a +// config in active use. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/vendor/github.com/gorilla/websocket/client_server_test.go b/vendor/github.com/gorilla/websocket/client_server_test.go new file mode 100644 index 00000000..1cb9b645 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_server_test.go @@ -0,0 +1,448 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" +) + +var cstUpgrader = Upgrader{ + Subprotocols: []string{"p0", "p1"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + http.Error(w, reason.Error(), status) + }, +} + +var cstDialer = Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +type cstHandler struct{ *testing.T } + +type cstServer struct { + *httptest.Server + URL string +} + +const ( + cstPath = "/a/b" + cstRawQuery = "x=y" + cstRequestURI = cstPath + "?" + cstRawQuery +) + +func newServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewServer(cstHandler{t}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func newTLSServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", 400) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", 400) + return + } + subprotos := Subprotocols(r) + if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { + t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) + http.Error(w, "bad protocol", 400) + return + } + ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + if ws.Subprotocol() != "p1" { + t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) + ws.Close() + return + } + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } +} + +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + +func sendRecv(t *testing.T, ws *Conn) { + const message = "Hello World!" + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } +} + +func TestProxyDial(t *testing.T) { + + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect not recieved") + http.Error(w, "connect not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + +func TestProxyAuthorizationDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + surl.User = url.UserPassword("username", "password") + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + proxyAuth := r.Header.Get("Proxy-Authorization") + expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect with proxy authorization not recieved") + http.Error(w, "connect with proxy authorization not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + +func TestDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialTLS(t *testing.T) { + s := newTLSServer(t) + defer s.Close() + + certs := x509.NewCertPool() + for _, c := range s.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: certs} + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func xTestDialTLSBadCert(t *testing.T) { + // This test is deactivated because of noisy logging from the net/http package. + s := newTLSServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +func TestDialTLSNoVerify(t *testing.T) { + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialTimeout(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.HandshakeTimeout = -1 + ws, _, err := d.Dial(s.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +func TestDialBadScheme(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.Server.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +func TestDialBadOrigin(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) + } +} + +func TestDialBadHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + for _, k := range []string{"Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol"} { + h := http.Header{} + h.Set(k, "bad") + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Errorf("Dial with header %s returned nil", k) + } + } +} + +func TestBadMethod(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + resp, err := http.PostForm(s.URL, url.Values{}) + if err != nil { + t.Fatalf("PostForm returned error %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + +func TestHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var sessionID string + for _, c := range resp.Cookies() { + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + if ws.Subprotocol() != "p1" { + t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) + } + sendRecv(t, ws) +} + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +} + +// TestHostHeader confirms that the host header provided in the call to Dial is +// sent to the server. +func TestHostHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + specifiedHost := make(chan string, 1) + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + specifiedHost <- r.Host + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + if gotHost := <-specifiedHost; gotHost != "testhost" { + t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + } + + sendRecv(t, ws) +} diff --git a/vendor/github.com/gorilla/websocket/client_test.go b/vendor/github.com/gorilla/websocket/client_test.go new file mode 100644 index 00000000..7d2b0844 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_test.go @@ -0,0 +1,72 @@ +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "net/url" + "reflect" + "testing" +) + +var parseURLTests = []struct { + s string + u *url.URL + rui string +}{ + {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, + {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, + {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, + {"ss://example.com/a/b", nil, ""}, + {"ws://webmaster@example.com/", nil, ""}, + {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, + {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, +} + +func TestParseURL(t *testing.T) { + for _, tt := range parseURLTests { + u, err := parseURL(tt.s) + if tt.u != nil && err != nil { + t.Errorf("parseURL(%q) returned error %v", tt.s, err) + continue + } + if tt.u == nil { + if err == nil { + t.Errorf("parseURL(%q) did not return error", tt.s) + } + continue + } + if !reflect.DeepEqual(u, tt.u) { + t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) + continue + } + if u.RequestURI() != tt.rui { + t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) + } + } +} + +var hostPortNoPortTests = []struct { + u *url.URL + hostPort, hostNoPort string +}{ + {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, + {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, +} + +func TestHostPortNoPort(t *testing.T) { + for _, tt := range hostPortNoPortTests { + hostPort, hostNoPort := hostPortNoPort(tt.u) + if hostPort != tt.hostPort { + t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) + } + if hostNoPort != tt.hostNoPort { + t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) + } + } +} diff --git a/vendor/github.com/gorilla/websocket/compression.go b/vendor/github.com/gorilla/websocket/compression.go new file mode 100644 index 00000000..e2ac7617 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression.go @@ -0,0 +1,85 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" +) + +func decompressNoContextTakeover(r io.Reader) io.Reader { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) +} + +func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { + tw := &truncWriter{w: w} + fw, err := flate.NewWriter(tw, 3) + return &flateWrapper{fw: fw, tw: tw}, err +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWrapper struct { + fw *flate.Writer + tw *truncWriter +} + +func (w *flateWrapper) Write(p []byte) (int, error) { + return w.fw.Write(p) +} + +func (w *flateWrapper) Close() error { + err1 := w.fw.Flush() + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/vendor/github.com/gorilla/websocket/compression_test.go b/vendor/github.com/gorilla/websocket/compression_test.go new file mode 100644 index 00000000..cad70fb5 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression_test.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "bytes" + "io" + "testing" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go new file mode 100644 index 00000000..eb4334e7 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn.go @@ -0,0 +1,994 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents close frame. +type CloseError struct { + + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +// Conn represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan bool // used as mutex to protect write to conn and closeSent + closeSent bool // whether close message was sent + writeErr error + writeBuf []byte // frame is constructed in this buffer. + writePos int // end of data in writeBuf. + writeFrameType int // type of the current frame. + writeDeadline time.Time + messageWriter *messageWriter // the current low-level message writer + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + enableWriteCompression bool + writeCompress bool // whether next call to flushFrame should set RSV1 + newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) + + // Read fields + readErr error + br *bufio.Reader + readRemaining int64 // bytes remaining in current frame. + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.Reader +} + +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { + mu := make(chan bool, 1) + mu <- true + + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } + if writeBufferSize == 0 { + writeBufferSize = defaultWriteBufferSize + } + + c := &Conn{ + isServer: isServer, + br: bufio.NewReaderSize(conn, readBufferSize), + conn: conn, + mu: mu, + readFinal: true, + writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), + writeFrameType: noFrame, + writePos: maxFrameHeaderSize, + enableWriteCompression: true, + } + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// Subprotocol returns the negotiated protocol for the connection. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +// Close closes the underlying network connection without sending or waiting for a close frame. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { + <-c.mu + defer func() { c.mu <- true }() + + if c.closeSent { + return ErrCloseSent + } else if frameType == CloseMessage { + c.closeSent = true + } + + c.conn.SetWriteDeadline(deadline) + for _, buf := range bufs { + if len(buf) > 0 { + n, err := c.conn.Write(buf) + if n != len(buf) { + // Close on partial write. + c.conn.Close() + } + if err != nil { + return err + } + } + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := time.Hour * 1000 + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- true }() + + if c.closeSent { + return ErrCloseSent + } else if messageType == CloseMessage { + c.closeSent = true + } + + c.conn.SetWriteDeadline(deadline) + n, err := c.conn.Write(buf) + if n != 0 && n != len(buf) { + c.conn.Close() + } + return hideTempErr(err) +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + if c.writeErr != nil { + return nil, c.writeErr + } + + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + err := c.writer.Close() + if err != nil { + return nil, err + } + } + + if !isControl(messageType) && !isData(messageType) { + return nil, errBadWriteOpCode + } + + c.writeFrameType = messageType + c.messageWriter = &messageWriter{c} + + var w io.WriteCloser = c.messageWriter + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + c.writeCompress = true + var err error + w, err = c.newCompressionWriter(w) + if err != nil { + c.writer.Close() + return nil, err + } + } + + return w, nil +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (c *Conn) flushFrame(final bool, extra []byte) error { + length := c.writePos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(c.writeFrameType) && + (!final || length > maxControlFramePayloadSize) { + c.messageWriter = nil + c.writer = nil + c.writeFrameType = noFrame + c.writePos = maxFrameHeaderSize + return errInvalidControlFrame + } + + b0 := byte(c.writeFrameType) + if final { + b0 |= finalBit + } + if c.writeCompress { + b0 |= rsv1Bit + } + c.writeCompress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) + if len(extra) > 0 { + c.writeErr = errors.New("websocket: internal error, extra used in client mode") + return c.writeErr + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + // Setup for next frame. + c.writePos = maxFrameHeaderSize + c.writeFrameType = continuationFrame + if final { + c.messageWriter = nil + c.writer = nil + c.writeFrameType = noFrame + } + return c.writeErr +} + +type messageWriter struct{ c *Conn } + +func (w *messageWriter) err() error { + c := w.c + if c.messageWriter != w { + return errWriteClosed + } + if c.writeErr != nil { + return c.writeErr + } + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.c.writePos + if n <= 0 { + if err := w.c.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.c.writePos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if err := w.err(); err != nil { + return 0, err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.c.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.c.writePos:], p[:n]) + w.c.writePos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if err := w.err(); err != nil { + return 0, err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.c.writePos:], p[:n]) + w.c.writePos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if err := w.err(); err != nil { + return 0, err + } + for { + if w.c.writePos == len(w.c.writeBuf) { + err = w.c.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.c.writePos:]) + w.c.writePos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if err := w.err(); err != nil { + return err + } + return w.c.flushFrame(true, nil) +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, ok := w.(*messageWriter); ok && c.isServer { + // Optimize write as a single frame. + n := copy(c.writeBuf[c.writePos:], data) + c.writePos += n + data = data[n:] + err = c.flushFrame(true, data) + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.readRemaining = int64(p[1] & 0x7f) + + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint16(p)) + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint64(p)) + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.readRemaining = 0 + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + echoMessage := []byte{} + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + echoMessage = payload[:2] + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("invalid close code") + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait)) + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + var r io.Reader = c.messageReader + if c.readDecompress { + r = c.newDecompressionReader(r) + } + return frameType, r, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + c.readRemaining -= int64(n) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size for a message read from the peer. If a +// message exceeds the limit, the connection sends a close frame to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING frame application data. The default +// ping handler sends a pong to the peer. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG frame application data. The default +// pong handler does nothing. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +func FormatCloseMessage(closeCode int, text string) []byte { + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/vendor/github.com/gorilla/websocket/conn_read.go b/vendor/github.com/gorilla/websocket/conn_read.go new file mode 100644 index 00000000..1ea15059 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_read.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} diff --git a/vendor/github.com/gorilla/websocket/conn_read_legacy.go b/vendor/github.com/gorilla/websocket/conn_read_legacy.go new file mode 100644 index 00000000..018541cf --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_read_legacy.go @@ -0,0 +1,21 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + if len(p) > 0 { + // advance over the bytes just read + io.ReadFull(c.br, p) + } + return p, err +} diff --git a/vendor/github.com/gorilla/websocket/conn_test.go b/vendor/github.com/gorilla/websocket/conn_test.go new file mode 100644 index 00000000..0243c115 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_test.go @@ -0,0 +1,402 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "reflect" + "testing" + "testing/iotest" + "time" +) + +var _ net.Error = errWriteTimeout + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return nil } +func (c fakeNetConn) RemoteAddr() net.Addr { return nil } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestFraming(t *testing.T) { + frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} + var readChunkers = []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + + for _, n := range frameSizes { + for _, iocopy := range []bool{true, false} { + name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + var nn int + if iocopy { + var n64 int64 + n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) + nn = int(n64) + } else { + nn, err = w.Write(writeBuf[:n]) + } + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong messsage" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + if isWriteControl { + wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + +func TestCloseBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) + } + _, _, err = rc.NextReader() + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) + } +} + +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) + rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + +func TestReadLimit(t *testing.T) { + + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } +} + +func TestUnderlyingConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024) + ul := c.UnderlyingConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestBufioReadBytes(t *testing.T) { + + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) + } +} + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + go func() { + c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + c.ReadMessage() + } + t.Fatal("should not get here") +} diff --git a/vendor/github.com/gorilla/websocket/doc.go b/vendor/github.com/gorilla/websocket/doc.go new file mode 100644 index 00000000..c901a7a9 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/doc.go @@ -0,0 +1,152 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. A server application uses +// the Upgrade function from an Upgrader object with a HTTP request handler +// to get a pointer to a Conn: +// +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// Call the connection's WriteMessage and ReadMessage methods to send and +// receive messages as a slice of bytes. This snippet of code shows how to echo +// messages using these methods: +// +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// return +// } +// if err = conn.WriteMessage(messageType, p); err != nil { +// return err +// } +// } +// +// In above snippet of code, p is a []byte and messageType is an int with value +// websocket.BinaryMessage or websocket.TextMessage. +// +// An application can also send and receive messages using the io.WriteCloser +// and io.Reader interfaces. To send a message, call the connection NextWriter +// method to get an io.WriteCloser, write the message to the writer and close +// the writer when done. To receive a message, call the connection NextReader +// method to get an io.Reader and read until io.EOF is returned. This snippet +// shows how to echo messages using the NextWriter and NextReader methods: +// +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the TextMessage and BinaryMessage integer constants to +// identify the two data message types. The ReadMessage and NextReader methods +// return the type of the received message. The messageType argument to the +// WriteMessage and NextWriter methods specifies the type of a sent message. +// +// It is the application's responsibility to ensure that text messages are +// valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received close messages by sending a close message to the +// peer and returning a *CloseError from the the NextReader, ReadMessage or the +// message Read method. +// +// Connections handle received ping and pong messages by invoking callback +// functions set with SetPingHandler and SetPongHandler methods. The callback +// functions are called from the NextReader, ReadMessage and the message Read +// methods. +// +// The default ping handler sends a pong to the peer. The application's reading +// goroutine can block for a short time while the handler writes the pong data +// to the connection. +// +// The application must read the connection to process ping, pong and close +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } +// +// Concurrency +// +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON) concurrently and that no more than one goroutine calls the read +// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, +// SetPingHandler) concurrently. +// +// The Close and WriteControl methods can be called concurrently with all other +// methods. +// +// Origin Considerations +// +// Web browsers allow Javascript applications to open a WebSocket connection to +// any host. It's up to the server to enforce an origin policy using the Origin +// request header sent by the browser. +// +// The Upgrader calls the function specified in the CheckOrigin field to check +// the origin. If the CheckOrigin function returns false, then the Upgrade +// method fails the WebSocket handshake with HTTP status 403. +// +// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail +// the handshake if the Origin request header is present and not equal to the +// Host request header. +// +// An application can allow connections from any origin by specifying a +// function that always returns true: +// +// var upgrader = websocket.Upgrader{ +// CheckOrigin: func(r *http.Request) bool { return true }, +// } +// +// The deprecated Upgrade function does not enforce an origin policy. It's the +// application's responsibility to check the Origin header before calling +// Upgrade. +package websocket diff --git a/vendor/github.com/gorilla/websocket/example_test.go b/vendor/github.com/gorilla/websocket/example_test.go new file mode 100644 index 00000000..96449eac --- /dev/null +++ b/vendor/github.com/gorilla/websocket/example_test.go @@ -0,0 +1,46 @@ +// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket_test + +import ( + "log" + "net/http" + "testing" + + "github.com/gorilla/websocket" +) + +var ( + c *websocket.Conn + req *http.Request +) + +// The websocket.IsUnexpectedCloseError function is useful for identifying +// application and protocol errors. +// +// This server application works with a client application running in the +// browser. The client application does not explicitly close the websocket. The +// only expected close message from the client has the code +// websocket.CloseGoingAway. All other other close messages are likely the +// result of an application or protocol error and are logged to aid debugging. +func ExampleIsUnexpectedCloseError() { + + for { + messageType, p, err := c.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent")) + } + return + } + processMesage(messageType, p) + } +} + +func processMesage(mt int, p []byte) {} + +// TestX prevents godoc from showing this entire file in the example. Remove +// this function when a second example is added. +func TestX(t *testing.T) {} diff --git a/vendor/github.com/gorilla/websocket/json.go b/vendor/github.com/gorilla/websocket/json.go new file mode 100644 index 00000000..4f0e3687 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/json.go @@ -0,0 +1,55 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "encoding/json" + "io" +) + +// WriteJSON is deprecated, use c.WriteJSON instead. +func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v to the connection. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON is deprecated, use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/vendor/github.com/gorilla/websocket/json_test.go b/vendor/github.com/gorilla/websocket/json_test.go new file mode 100644 index 00000000..61100e48 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/json_test.go @@ -0,0 +1,119 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "encoding/json" + "io" + "reflect" + "testing" +) + +func TestJSON(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := wc.WriteJSON(&expect); err != nil { + t.Fatal("write", err) + } + + if err := rc.ReadJSON(&actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} + +func TestPartialJSONRead(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var v struct { + A int + B string + } + v.A = 1 + v.B = "hello" + + messageCount := 0 + + // Partial JSON values. + + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + for i := len(data) - 1; i >= 0; i-- { + if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { + t.Fatal(err) + } + messageCount++ + } + + // Whitespace. + + if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + messageCount++ + + // Close. + + if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + + for i := 0; i < messageCount; i++ { + err := rc.ReadJSON(&v) + if err != io.ErrUnexpectedEOF { + t.Error("read", i, err) + } + } + + err = rc.ReadJSON(&v) + if _, ok := err.(*CloseError); !ok { + t.Error("final", err) + } +} + +func TestDeprecatedJSON(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := WriteJSON(wc, &expect); err != nil { + t.Fatal("write", err) + } + + if err := ReadJSON(rc, &actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go new file mode 100644 index 00000000..8402d20b --- /dev/null +++ b/vendor/github.com/gorilla/websocket/server.go @@ -0,0 +1,261 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // size is zero, then a default value of 4096 is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is set, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(w http.ResponseWriter, r *http.Request, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, the host in the Origin header must not be set or + // must match the host of the request. + CheckOrigin func(r *http.Request) bool +} + +func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(w, r, status, err) + } else { + w.Header().Set("Sec-Websocket-Version", "13") + http.Error(w, http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return u.Host == r.Host +} + +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(r) + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return responseHeader.Get("Sec-Websocket-Protocol") + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-Websocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + if r.Method != "GET" { + return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") + } + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") + } + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if challengeKey == "" { + return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank") + } + + subprotocol := u.selectSubprotocol(r, responseHeader) + + var ( + netConn net.Conn + br *bufio.Reader + err error + ) + + h, ok := w.(http.Hijacker) + if !ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") + } + var rw *bufio.ReadWriter + netConn, rw, err = h.Hijack() + if err != nil { + return u.returnError(w, r, http.StatusInternalServerError, err.Error()) + } + br = rw.Reader + + if br.Buffered() > 0 { + netConn.Close() + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c.subprotocol = subprotocol + + p := c.writeBuf[:0] + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + if c.subprotocol != "" { + p = append(p, "Sec-Websocket-Protocol: "...) + p = append(p, c.subprotocol...) + p = append(p, "\r\n"...) + } + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + if _, err = netConn.Write(p); err != nil { + netConn.Close() + return nil, err + } + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Time{}) + } + + return c, nil +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// This function is deprecated, use websocket.Upgrader instead. +// +// The application is responsible for checking the request origin before +// calling Upgrade. An example implementation of the same origin policy is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", 403) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(r *http.Request) bool { + // allow all connections by default + return true + } + return u.Upgrade(w, r, responseHeader) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(r *http.Request) []string { + h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} diff --git a/vendor/github.com/gorilla/websocket/server_test.go b/vendor/github.com/gorilla/websocket/server_test.go new file mode 100644 index 00000000..0a28141d --- /dev/null +++ b/vendor/github.com/gorilla/websocket/server_test.go @@ -0,0 +1,51 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "net/http" + "reflect" + "testing" +) + +var subprotocolTests = []struct { + h string + protocols []string +}{ + {"", nil}, + {"foo", []string{"foo"}}, + {"foo,bar", []string{"foo", "bar"}}, + {"foo, bar", []string{"foo", "bar"}}, + {" foo, bar", []string{"foo", "bar"}}, + {" foo, bar ", []string{"foo", "bar"}}, +} + +func TestSubprotocols(t *testing.T) { + for _, st := range subprotocolTests { + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} + protocols := Subprotocols(&r) + if !reflect.DeepEqual(st.protocols, protocols) { + t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols) + } + } +} + +var isWebSocketUpgradeTests = []struct { + ok bool + h http.Header +}{ + {false, http.Header{"Upgrade": {"websocket"}}}, + {false, http.Header{"Connection": {"upgrade"}}}, + {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, +} + +func TestIsWebSocketUpgrade(t *testing.T) { + for _, tt := range isWebSocketUpgradeTests { + ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) + if tt.ok != ok { + t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) + } + } +} diff --git a/vendor/github.com/gorilla/websocket/util.go b/vendor/github.com/gorilla/websocket/util.go new file mode 100644 index 00000000..9a4908df --- /dev/null +++ b/vendor/github.com/gorilla/websocket/util.go @@ -0,0 +1,214 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Octet types from RFC 2616. +var octetTypes [256]byte + +const ( + isTokenOctet = 1 << iota + isSpaceOctet +) + +func init() { + // From RFC 2616 + // + // OCTET = + // CHAR = + // CTL = + // CR = + // LF = + // SP = + // HT = + // <"> = + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1* + // qdtext = > + + for c := 0; c < 256; c++ { + var t byte + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 + if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { + t |= isSpaceOctet + } + if isChar && !isCtl && !isSeparator { + t |= isTokenOctet + } + octetTypes[c] = t + } +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpaceOctet == 0 { + break + } + } + return s[i:] +} + +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isTokenOctet == 0 { + break + } + } + return s[:i], s[i:] +} + +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j += 1 + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j += 1 + } + } + return "", "" + } + } + return "", "" +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains token. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if strings.EqualFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensiosn parses WebSocket extensions from a header. +func parseExtensions(header http.Header) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/vendor/github.com/gorilla/websocket/util_test.go b/vendor/github.com/gorilla/websocket/util_test.go new file mode 100644 index 00000000..610e613c --- /dev/null +++ b/vendor/github.com/gorilla/websocket/util_test.go @@ -0,0 +1,74 @@ +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "net/http" + "reflect" + "testing" +) + +var tokenListContainsValueTests = []struct { + value string + ok bool +}{ + {"WebSocket", true}, + {"WEBSOCKET", true}, + {"websocket", true}, + {"websockets", false}, + {"x websocket", false}, + {"websocket x", false}, + {"other,websocket,more", true}, + {"other, websocket, more", true}, +} + +func TestTokenListContainsValue(t *testing.T) { + for _, tt := range tokenListContainsValueTests { + h := http.Header{"Upgrade": {tt.value}} + ok := tokenListContainsValue(h, "Upgrade", "websocket") + if ok != tt.ok { + t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) + } + } +} + +var parseExtensionTests = []struct { + value string + extensions []map[string]string +}{ + {`foo`, []map[string]string{map[string]string{"": "foo"}}}, + {`foo, bar; baz=2`, []map[string]string{ + map[string]string{"": "foo"}, + map[string]string{"": "bar", "baz": "2"}}}, + {`foo; bar="b,a;z"`, []map[string]string{ + map[string]string{"": "foo", "bar": "b,a;z"}}}, + {`foo , bar; baz = 2`, []map[string]string{ + map[string]string{"": "foo"}, + map[string]string{"": "bar", "baz": "2"}}}, + {`foo, bar; baz=2 junk`, []map[string]string{ + map[string]string{"": "foo"}}}, + {`foo junk, bar; baz=2 junk`, nil}, + {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ + map[string]string{"": "mux", "max-channels": "4", "flow-control": ""}, + map[string]string{"": "deflate-stream"}}}, + {`permessage-foo; x="10"`, []map[string]string{ + map[string]string{"": "permessage-foo", "x": "10"}}}, + {`permessage-foo; use_y, permessage-foo`, []map[string]string{ + map[string]string{"": "permessage-foo", "use_y": ""}, + map[string]string{"": "permessage-foo"}}}, + {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ + map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, + map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}}, +} + +func TestParseExtensions(t *testing.T) { + for _, tt := range parseExtensionTests { + h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} + extensions := parseExtensions(h) + if !reflect.DeepEqual(extensions, tt.extensions) { + t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) + } + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json new file mode 100644 index 00000000..6402559d --- /dev/null +++ b/vendor/vendor.json @@ -0,0 +1,13 @@ +{ + "comment": "", + "ignore": "", + "package": [ + { + "checksumSHA1": "Tknk9q8ncICzdw+etBD6UViMhvc=", + "path": "github.com/gorilla/websocket", + "revision": "2d1e4548da234d9cb742cc3628556fef86aafbac", + "revisionTime": "2016-09-12T15:30:41Z" + } + ], + "rootPath": "github.com/elazarl/goproxy" +} From e0980d327f9dbb54f0990d1ffda8ffd5e00f2172 Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Wed, 14 Sep 2016 19:32:03 +0300 Subject: [PATCH 2/8] WIP --- .gitignore | 1 + ctx.go | 17 +-- examples/goproxy-websocket/main.go | 31 ++++++ https.go | 6 +- proxy.go | 159 +++++++++++++++++++++-------- websocket.go | 38 +++++++ 6 files changed, 201 insertions(+), 51 deletions(-) create mode 100644 examples/goproxy-websocket/main.go create mode 100644 websocket.go diff --git a/.gitignore b/.gitignore index 1005f6f1..5861daaf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +/.idea/ bin *.swp diff --git a/ctx.go b/ctx.go index 95bfd800..e875c4d7 100644 --- a/ctx.go +++ b/ctx.go @@ -9,18 +9,19 @@ import ( // every user function. Also used as a logger. type ProxyCtx struct { // Will contain the client request from the proxy - Req *http.Request + Req *http.Request // Will contain the remote server's response (if available. nil if the request wasn't send yet) Resp *http.Response RoundTripper RoundTripper // will contain the recent error that occured while trying to send receive or parse traffic - Error error + Error error // A handle for the user to keep data in the context, from the call of ReqHandler to the // call of RespHandler - UserData interface{} + UserData interface{} // Will connect a request to a response - Session int64 - proxy *ProxyHttpServer + Session int64 + Websocket bool + proxy *ProxyHttpServer } type RoundTripper interface { @@ -41,7 +42,7 @@ func (ctx *ProxyCtx) RoundTrip(req *http.Request) (*http.Response, error) { } func (ctx *ProxyCtx) printf(msg string, argv ...interface{}) { - ctx.proxy.Logger.Printf("[%03d] "+msg+"\n", append([]interface{}{ctx.Session & 0xFF}, argv...)...) + ctx.proxy.Logger.Printf("[%03d] " + msg + "\n", append([]interface{}{ctx.Session & 0xFF}, argv...)...) } // Logf prints a message to the proxy's log. Should be used in a ProxyHttpServer's filter @@ -54,7 +55,7 @@ func (ctx *ProxyCtx) printf(msg string, argv ...interface{}) { // }) func (ctx *ProxyCtx) Logf(msg string, argv ...interface{}) { if ctx.proxy.Verbose { - ctx.printf("INFO: "+msg, argv...) + ctx.printf("INFO: " + msg, argv...) } } @@ -70,7 +71,7 @@ func (ctx *ProxyCtx) Logf(msg string, argv ...interface{}) { // return r, nil // }) func (ctx *ProxyCtx) Warnf(msg string, argv ...interface{}) { - ctx.printf("WARN: "+msg, argv...) + ctx.printf("WARN: " + msg, argv...) } var charsetFinder = regexp.MustCompile("charset=([^ ;]*)") diff --git a/examples/goproxy-websocket/main.go b/examples/goproxy-websocket/main.go new file mode 100644 index 00000000..83149ed6 --- /dev/null +++ b/examples/goproxy-websocket/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "github.com/elazarl/goproxy" + "log" + "net/http" + "regexp" +) + +func main() { + // Init + https := regexp.MustCompile("^.*:(443|8443)$") + + proxy := goproxy.NewProxyHttpServer() + proxy.Verbose = true + + // MitM + proxy.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + if ctx.Req.Header.Get("Connection") == "Upgrade" { + return goproxy.RejectConnect, host + } + + if https.MatchString(host) { + return goproxy.MitmConnect, host + } else { + return goproxy.HTTPMitmConnect, host + } + }) + + log.Fatal(http.ListenAndServe(":8888", proxy)) +} diff --git a/https.go b/https.go index 4ade3bde..c1b4111c 100644 --- a/https.go +++ b/https.go @@ -89,6 +89,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request break } } + switch todo.Action { case ConnectAccept: if !hasPort.MatchString(host) { @@ -116,7 +117,6 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request wg.Wait() proxyClient.Close() targetSiteCon.Close() - }() } @@ -124,6 +124,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Logf("Hijacking CONNECT to %s", host) proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) todo.Hijack(r, proxyClient, ctx) + case ConnectHTTPMitm: proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") @@ -161,6 +162,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request return } } + case ConnectMitm: proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is TLS, mitm proxying it") @@ -261,9 +263,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request } ctx.Logf("Exiting on EOF") }() + case ConnectProxyAuthHijack: proxyClient.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n")) todo.Hijack(r, proxyClient, ctx) + case ConnectReject: if ctx.Resp != nil { if err := ctx.Resp.Write(proxyClient); err != nil { diff --git a/proxy.go b/proxy.go index fefb3bb0..f711367b 100644 --- a/proxy.go +++ b/proxy.go @@ -9,13 +9,16 @@ import ( "os" "regexp" "sync/atomic" + "github.com/gorilla/websocket" + "sync" + "fmt" ) // The basic proxy type. Implements http.Handler. type ProxyHttpServer struct { // session variable must be aligned in i386 // see http://golang.org/src/pkg/sync/atomic/doc.go#L41 - sess int64 + sess int64 // setting Verbose to true will log information on each request sent to the proxy Verbose bool Logger *log.Logger @@ -26,7 +29,9 @@ type ProxyHttpServer struct { Tr *http.Transport // ConnectDial will be used to create TCP connections for CONNECT requests // if nil Tr.Dial will be used - ConnectDial func(network string, addr string) (net.Conn, error) + ConnectDial func(network string, addr string) (net.Conn, error) + WsServer *websocket.Upgrader + WsDialer *websocket.Dialer } var hasPort = regexp.MustCompile(`:\d+$`) @@ -93,54 +98,124 @@ func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { // Standard net/http function. Shouldn't be used directly, http.Serve will use it. func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - //r.Header["X-Forwarded-For"] = w.RemoteAddr() if r.Method == "CONNECT" { proxy.handleHttps(w, r) + } else if !r.URL.IsAbs() { + proxy.NonproxyHandler.ServeHTTP(w, r) } else { - ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} + proxy.handleRequest(w, r) + } +} - var err error - ctx.Logf("Got request %v %v %v %v", r.URL.Path, r.Host, r.Method, r.URL.String()) - if !r.URL.IsAbs() { - proxy.NonproxyHandler.ServeHTTP(w, r) - return - } - r, resp := proxy.filterRequest(r, ctx) - - if resp == nil { - removeProxyHeaders(ctx, r) - resp, err = ctx.RoundTrip(r) - if err != nil { - ctx.Error = err - resp = proxy.filterResponse(nil, ctx) - if resp == nil { - ctx.Logf("error read response %v %v:", r.URL.Host, err.Error()) - http.Error(w, err.Error(), 500) - return +func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http.Request) { + ctx := &ProxyCtx{ + Req: base, + Session: atomic.AddInt64(&proxy.sess, 1), + Websocket: websocket.IsWebSocketUpgrade(base), + proxy: proxy, + } + + if websocket.IsWebSocketUpgrade(base) { + proto := websocket.Subprotocols(base) + wg := &sync.WaitGroup{} + + ctx.Logf("Relying websocket connection with protocols: %v", proto) + + remote, _, _ := proxy.WsDialer.Dial( + base.URL.String(), + nil, + ) + + client, _ := proxy.WsServer.Upgrade(out, base, nil) + + wg.Add(2) + + go wsRelay(ctx, remote, client, wg) + go wsRelay(ctx, client, remote, wg) + + wg.Wait() + + remote.Close() + client.Close() + + return + } + + ctx.Logf("Got request %v %v %v %v", base.URL.Path, base.Host, base.Method, base.URL.String()) + + var ( + req *http.Request + resp *http.Response + err error + ) + + req, resp = proxy.filterRequest(base, ctx) + + if resp == nil { + removeProxyHeaders(ctx, req) + + resp, err = ctx.RoundTrip(req) + + if err != nil { + ctx.Logf("Error reading response %v: %v", req.URL.Host, err.Error()) + + ctx.Error = err + resp = proxy.filterResponse(nil, ctx) + + if resp == nil { + // TODO: add gateway timeout error in case of timeout + switch err { + default: + http.Error(out, err.Error(), http.StatusBadGateway) } + + return } - ctx.Logf("Received response %v", resp.Status) - } - origBody := resp.Body - resp = proxy.filterResponse(resp, ctx) - defer origBody.Close() - ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) - // http.ResponseWriter will take care of filling the correct response length - // Setting it now, might impose wrong value, contradicting the actual new - // body the user returned. - // We keep the original body to remove the header only if things changed. - // This will prevent problems with HEAD requests where there's no body, yet, - // the Content-Length header should be set. - if origBody != resp.Body { - resp.Header.Del("Content-Length") } - copyHeaders(w.Header(), resp.Header) - w.WriteHeader(resp.StatusCode) - nr, err := io.Copy(w, resp.Body) - if err := resp.Body.Close(); err != nil { - ctx.Warnf("Can't close response body %v", err) + } + + fmt.Printf("%v", resp) + + body := resp.Body + defer body.Close() + + resp = proxy.filterResponse(resp, ctx) + + // http.ResponseWriter will take care of filling the correct response length + // Setting it now, might impose wrong value, contradicting the actual new + // body the user returned. + // We keep the original body to remove the header only if things changed. + // This will prevent problems with HEAD requests where there's no body, yet, + // the Content-Length header should be set. + if body != resp.Body { + resp.Header.Del("Content-Length") + } + + ctx.Logf("Received response: %v", resp.Status) + ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) + + for k, _ := range out.Header() { + out.Header().Del(k) + } + + for k, vs := range resp.Header { + for _, v := range vs { + out.Header().Add(k, v) } - ctx.Logf("Copied %v bytes to client error=%v", nr, err) + } + + out.WriteHeader(resp.StatusCode) + + nr, err := io.Copy(out, resp.Body) + + if err != nil { + ctx.Logf("Copied %v bytes to client with error: %v", nr, err) + } else { + ctx.Logf("Copied %v bytes to client", nr) + } + + if err := resp.Body.Close(); err != nil { + ctx.Warnf("Can't close response body: %v", err) } } diff --git a/websocket.go b/websocket.go new file mode 100644 index 00000000..fc849648 --- /dev/null +++ b/websocket.go @@ -0,0 +1,38 @@ +package goproxy + +import ( + "github.com/gorilla/websocket" + "io" + "sync" +) + +func wsRelay(ctx *ProxyCtx, src, dst *websocket.Conn, wg *sync.WaitGroup) { + // TODO add detection of graceful shutdown (via t == websocket.CloseMessage) + + // To avoid allocation of temp buf in io.Copy() + buf := make([]byte, 4 * 1024) + + for { + t, in, err := src.NextReader() + + if err != nil { + break + } + + out, err := dst.NextWriter(t) + + if err != nil { + break + } + + if _, err := io.CopyBuffer(out, in, buf); err != nil { + break + } + + if err := out.Close(); err != nil { + break + } + } + + wg.Done() +} From 5e71aef0f67fac55029400846c6f427100532071 Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Thu, 15 Sep 2016 17:50:09 +0300 Subject: [PATCH 3/8] WIP --- https.go | 2 +- proxy.go | 138 ++++++++++++++++++++++++++++++++----------------------- util.go | 45 ++++++++++++++++++ 3 files changed, 126 insertions(+), 59 deletions(-) create mode 100644 util.go diff --git a/https.go b/https.go index c1b4111c..2303e5c3 100644 --- a/https.go +++ b/https.go @@ -64,7 +64,7 @@ func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err return proxy.ConnectDial(network, addr) } -func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request) { +func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Request) { ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} hij, ok := w.(http.Hijacker) diff --git a/proxy.go b/proxy.go index f711367b..6a90c9dd 100644 --- a/proxy.go +++ b/proxy.go @@ -11,7 +11,6 @@ import ( "sync/atomic" "github.com/gorilla/websocket" "sync" - "fmt" ) // The basic proxy type. Implements http.Handler. @@ -96,10 +95,40 @@ func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { r.Header.Del("Connection") } +func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) { + ctx.Logf("Copying response to client: %v [%d]", resp.Status, resp.StatusCode) + + // 1 + for k, _ := range out.Header() { + out.Header().Del(k) + } + + for k, vs := range resp.Header { + for _, v := range vs { + out.Header().Add(k, v) + } + } + + // 2 + out.WriteHeader(resp.StatusCode) + + // 3 + if nr, err := io.Copy(out, resp.Body); err != nil { + ctx.Logf("Copied %v bytes to client with error: %v", nr, err) + } else { + ctx.Logf("Copied %v bytes to client", nr) + } + + // 4 + if err := resp.Body.Close(); err != nil { + ctx.Warnf("Can't close response body: %v", err) + } +} + // Standard net/http function. Shouldn't be used directly, http.Serve will use it. func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { - proxy.handleHttps(w, r) + proxy.handleConnect(w, r) } else if !r.URL.IsAbs() { proxy.NonproxyHandler.ServeHTTP(w, r) } else { @@ -116,65 +145,80 @@ func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http. } if websocket.IsWebSocketUpgrade(base) { - proto := websocket.Subprotocols(base) - wg := &sync.WaitGroup{} - - ctx.Logf("Relying websocket connection with protocols: %v", proto) + proxy.handleWsRequest(ctx, out, base) + } else { + proxy.handleHttpRequest(ctx, out, base) + } +} - remote, _, _ := proxy.WsDialer.Dial( - base.URL.String(), - nil, - ) +// TODO add handshake filter and message introspection +func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWriter, base *http.Request) { + proto := websocket.Subprotocols(base) + wg := &sync.WaitGroup{} - client, _ := proxy.WsServer.Upgrade(out, base, nil) + ctx.Logf("Relying websocket connection with protocols: %v", proto) - wg.Add(2) + remote, resp, err := proxy.WsDialer.Dial( + base.URL.String(), + nil, + ) - go wsRelay(ctx, remote, client, wg) - go wsRelay(ctx, client, remote, wg) + if err != nil { + if err == websocket.ErrBadHandshake { + writeResponse(ctx, resp, out) + } else { + http.Error(out, err.Error(), http.StatusBadGateway) + } - wg.Wait() + return + } - remote.Close() - client.Close() + client, err := proxy.WsServer.Upgrade(out, base, nil) + if err != nil { return } - ctx.Logf("Got request %v %v %v %v", base.URL.Path, base.Host, base.Method, base.URL.String()) + wg.Add(2) + + go wsRelay(ctx, remote, client, wg) + go wsRelay(ctx, client, remote, wg) + + wg.Wait() + remote.Close() + client.Close() + + return +} + +func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, out http.ResponseWriter, base *http.Request) { var ( req *http.Request resp *http.Response err error ) + ctx.Logf("Relying http(s) request to: %v", base.URL.String()) + req, resp = proxy.filterRequest(base, ctx) if resp == nil { removeProxyHeaders(ctx, req) - resp, err = ctx.RoundTrip(req) + } - if err != nil { - ctx.Logf("Error reading response %v: %v", req.URL.Host, err.Error()) - - ctx.Error = err - resp = proxy.filterResponse(nil, ctx) - - if resp == nil { - // TODO: add gateway timeout error in case of timeout - switch err { - default: - http.Error(out, err.Error(), http.StatusBadGateway) - } + if err != nil { + ctx.Logf("Error reading response %v: %v", req.URL.Host, err.Error()) - return - } + // TODO: add gateway timeout error in case of timeout + switch err { + default: + http.Error(out, err.Error(), http.StatusBadGateway) } - } - fmt.Printf("%v", resp) + return + } body := resp.Body defer body.Close() @@ -194,29 +238,7 @@ func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http. ctx.Logf("Received response: %v", resp.Status) ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) - for k, _ := range out.Header() { - out.Header().Del(k) - } - - for k, vs := range resp.Header { - for _, v := range vs { - out.Header().Add(k, v) - } - } - - out.WriteHeader(resp.StatusCode) - - nr, err := io.Copy(out, resp.Body) - - if err != nil { - ctx.Logf("Copied %v bytes to client with error: %v", nr, err) - } else { - ctx.Logf("Copied %v bytes to client", nr) - } - - if err := resp.Body.Close(); err != nil { - ctx.Warnf("Can't close response body: %v", err) - } + writeResponse(ctx, resp, out) } // New proxy server, logs to StdErr by default diff --git a/util.go b/util.go new file mode 100644 index 00000000..1994bf9e --- /dev/null +++ b/util.go @@ -0,0 +1,45 @@ +package goproxy + +import ( + "net/http" + "net" + "bufio" + "errors" +) + +// This response writer can be hijacked multiple times +type hijackedResponseWriter struct { + nested http.ResponseWriter + conn net.Conn + err error +} + +func (writer *hijackedResponseWriter) Header() http.Header { + return writer.nested.Header() +} + +func (writer *hijackedResponseWriter) Write(data []byte) (int, error) { + return writer.nested.Write(data) +} + +func (writer *hijackedResponseWriter) WriteHeader(code int) { + writer.nested.WriteHeader(code) +} + +func (writer *hijackedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := writer.nested.(http.Hijacker) + + if !ok { + return nil, nil, errors.New("proxy: nested http.ResponseWriter does not implement http.Hijacker interface") + } + + if !writer.hijacked() { + writer.conn, _, writer.err = hijacker.Hijack() + } + + return writer.conn, nil, writer.err +} + +func (writer *hijackedResponseWriter) hijacked() bool { + return writer.conn != nil || writer.err != nil +} From bec814acb40de81ea41f3a02600ce6edeb51bedc Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Thu, 15 Sep 2016 18:57:06 +0300 Subject: [PATCH 4/8] http websocket finally working --- https.go | 40 +++++++++++++--------------------------- proxy.go | 30 ++++++++++++++++++++++++++++-- util.go | 14 ++++++++++++-- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/https.go b/https.go index 2303e5c3..f62a5900 100644 --- a/https.go +++ b/https.go @@ -66,6 +66,7 @@ func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Request) { ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} + w = NewHijackedResponseWriter(w) hij, ok := w.(http.Hijacker) if !ok { @@ -126,40 +127,25 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque todo.Hijack(r, proxyClient, ctx) case ConnectHTTPMitm: - proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") - targetSiteCon, err := proxy.connectDial("tcp", host) - if err != nil { - ctx.Warnf("Error dialing to %s: %s", host, err.Error()) - return - } + proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + for { client := bufio.NewReader(proxyClient) - remote := bufio.NewReader(targetSiteCon) req, err := http.ReadRequest(client) - if err != nil && err != io.EOF { - ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) - } + if err != nil { - return - } - req, resp := proxy.filterRequest(req, ctx) - if resp == nil { - if err := req.Write(targetSiteCon); err != nil { - httpError(proxyClient, ctx, err) - return + if err != io.EOF { + ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) } - resp, err = http.ReadResponse(remote, req) - if err != nil { - httpError(proxyClient, ctx, err) - return - } - defer resp.Body.Close() + + break } - resp = proxy.filterResponse(resp, ctx) - if err := resp.Write(proxyClient); err != nil { - httpError(proxyClient, ctx, err) - return + + req.URL, err = url.Parse("http://" + req.Host + req.URL.String()) + + if err := proxy.handleRequest(w, req); err != nil { + break } } diff --git a/proxy.go b/proxy.go index 6a90c9dd..e6f6112e 100644 --- a/proxy.go +++ b/proxy.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "github.com/gorilla/websocket" "sync" + "errors" ) // The basic proxy type. Implements http.Handler. @@ -33,6 +34,8 @@ type ProxyHttpServer struct { WsDialer *websocket.Dialer } +var ErrConnectionClosed = errors.New("http: no Location header in response") + var hasPort = regexp.MustCompile(`:\d+$`) func copyHeaders(dst, src http.Header) { @@ -128,15 +131,18 @@ func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) // Standard net/http function. Shouldn't be used directly, http.Serve will use it. func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { + // CONNECT requests: SSL and WebSockets proxy.handleConnect(w, r) } else if !r.URL.IsAbs() { + // Local requests proxy.NonproxyHandler.ServeHTTP(w, r) } else { + // Common HTTP proxy proxy.handleRequest(w, r) } } -func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http.Request) { +func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http.Request) error { ctx := &ProxyCtx{ Req: base, Session: atomic.AddInt64(&proxy.sess, 1), @@ -149,6 +155,8 @@ func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http. } else { proxy.handleHttpRequest(ctx, out, base) } + + return nil } // TODO add handshake filter and message introspection @@ -156,7 +164,15 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWr proto := websocket.Subprotocols(base) wg := &sync.WaitGroup{} - ctx.Logf("Relying websocket connection with protocols: %v", proto) + switch base.URL.Scheme { + case "http": + base.URL.Scheme = "ws" + + case "https": + base.URL.Scheme = "wss" + } + + ctx.Logf("Relying websocket connection %s with protocols: %v", base.URL.String(), proto) remote, resp, err := proxy.WsDialer.Dial( base.URL.String(), @@ -164,6 +180,8 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWr ) if err != nil { + ctx.Logf(err.Error()) + if err == websocket.ErrBadHandshake { writeResponse(ctx, resp, out) } else { @@ -176,6 +194,8 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWr client, err := proxy.WsServer.Upgrade(out, base, nil) if err != nil { + ctx.Logf(err.Error()) + return } @@ -253,6 +273,12 @@ func NewProxyHttpServer() *ProxyHttpServer { }), Tr: &http.Transport{TLSClientConfig: tlsClientSkipVerify, Proxy: http.ProxyFromEnvironment}, + + WsServer: &websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, } proxy.ConnectDial = dialerFromEnv(&proxy) return &proxy diff --git a/util.go b/util.go index 1994bf9e..ab3680bd 100644 --- a/util.go +++ b/util.go @@ -9,8 +9,12 @@ import ( // This response writer can be hijacked multiple times type hijackedResponseWriter struct { + // base writer nested http.ResponseWriter + + // HiJacked fields conn net.Conn + rw *bufio.ReadWriter err error } @@ -34,12 +38,18 @@ func (writer *hijackedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, err } if !writer.hijacked() { - writer.conn, _, writer.err = hijacker.Hijack() + writer.conn, writer.rw, writer.err = hijacker.Hijack() } - return writer.conn, nil, writer.err + return writer.conn, writer.rw, writer.err } func (writer *hijackedResponseWriter) hijacked() bool { return writer.conn != nil || writer.err != nil } + +func NewHijackedResponseWriter(nested http.ResponseWriter) *hijackedResponseWriter { + return &hijackedResponseWriter{ + nested: nested, + } +} From 8caa3d91653efee16479d756220c98023d67d375 Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Fri, 16 Sep 2016 19:46:59 +0300 Subject: [PATCH 5/8] connResponseWriter some critical bugs are here --- examples/goproxy-websocket/main.go | 4 -- https.go | 78 ++++++------------------------ proxy.go | 50 ++++++++++--------- util.go | 56 ++++++++++----------- 4 files changed, 67 insertions(+), 121 deletions(-) diff --git a/examples/goproxy-websocket/main.go b/examples/goproxy-websocket/main.go index 83149ed6..bd21cd19 100644 --- a/examples/goproxy-websocket/main.go +++ b/examples/goproxy-websocket/main.go @@ -16,10 +16,6 @@ func main() { // MitM proxy.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { - if ctx.Req.Header.Get("Connection") == "Upgrade" { - return goproxy.RejectConnect, host - } - if https.MatchString(host) { return goproxy.MitmConnect, host } else { diff --git a/https.go b/https.go index f62a5900..adf0548a 100644 --- a/https.go +++ b/https.go @@ -11,7 +11,6 @@ import ( "net/url" "os" "regexp" - "strconv" "strings" "sync" "sync/atomic" @@ -66,7 +65,6 @@ func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Request) { ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} - w = NewHijackedResponseWriter(w) hij, ok := w.(http.Hijacker) if !ok { @@ -144,7 +142,11 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque req.URL, err = url.Parse("http://" + req.Host + req.URL.String()) - if err := proxy.handleRequest(w, req); err != nil { + if end, err := proxy.handleRequest(NewConnResponseWriter(proxyClient), req); end { + if err != nil { + ctx.Warnf("Error during serving MITM HTTP request: %+#v", err) + } + break } } @@ -176,75 +178,23 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque clientTlsReader := bufio.NewReader(rawClientTls) for !isEof(clientTlsReader) { req, err := http.ReadRequest(clientTlsReader) - if err != nil && err != io.EOF { - return - } + if err != nil { - ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) - return - } - req.RemoteAddr = r.RemoteAddr // since we're converting the request, need to carry over the original connecting IP as well - ctx.Logf("req %v", r.Host) + if err != io.EOF { + ctx.Warnf("cannot read request of MITM HTTPS client: %+#v", err) + } - if !httpsRegexp.MatchString(req.URL.String()) { - req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) + break } - // Bug fix which goproxy fails to provide request - // information URL in the context when does HTTPS MITM - ctx.Req = req + req.URL, err = url.Parse("https://" + req.Host + req.URL.String()) - req, resp := proxy.filterRequest(req, ctx) - if resp == nil { - if err != nil { - ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path) - return - } - removeProxyHeaders(ctx, req) - resp, err = ctx.RoundTrip(req) + if end, err := proxy.handleRequest(NewConnResponseWriter(rawClientTls), req); end { if err != nil { - ctx.Warnf("Cannot read TLS response from mitm'd server %v", err) - return + ctx.Warnf("Error during serving MITM HTTPS request: %+#v", err) } - ctx.Logf("resp %v", resp.Status) - } - resp = proxy.filterResponse(resp, ctx) - defer resp.Body.Close() - text := resp.Status - statusCode := strconv.Itoa(resp.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } - // always use 1.1 to support chunked encoding - if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) - return - } - // Since we don't know the length of resp, return chunked encoded response - // TODO: use a more reasonable scheme - resp.Header.Del("Content-Length") - resp.Header.Set("Transfer-Encoding", "chunked") - if err := resp.Header.Write(rawClientTls); err != nil { - ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) - return - } - if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err) - return - } - chunked := newChunkedWriter(rawClientTls) - if _, err := io.Copy(chunked, resp.Body); err != nil { - ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) - return - } - if err := chunked.Close(); err != nil { - ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err) - return - } - if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err) - return + break } } ctx.Logf("Exiting on EOF") diff --git a/proxy.go b/proxy.go index e6f6112e..5e0f9991 100644 --- a/proxy.go +++ b/proxy.go @@ -11,7 +11,6 @@ import ( "sync/atomic" "github.com/gorilla/websocket" "sync" - "errors" ) // The basic proxy type. Implements http.Handler. @@ -34,8 +33,6 @@ type ProxyHttpServer struct { WsDialer *websocket.Dialer } -var ErrConnectionClosed = errors.New("http: no Location header in response") - var hasPort = regexp.MustCompile(`:\d+$`) func copyHeaders(dst, src http.Header) { @@ -101,6 +98,16 @@ func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) { ctx.Logf("Copying response to client: %v [%d]", resp.Status, resp.StatusCode) + // Fancy ResponseWriter + if w, ok := out.(*connResponseWriter); ok { + if err := resp.Write(w); err != nil { + ctx.Warnf("Error copying response: %s", err.Error()) + } + + return + } + + // Standard ResponseWriter // 1 for k, _ := range out.Header() { out.Header().Del(k) @@ -142,7 +149,7 @@ func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) } } -func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http.Request) error { +func (proxy *ProxyHttpServer) handleRequest(writer http.ResponseWriter, base *http.Request) (bool, error) { ctx := &ProxyCtx{ Req: base, Session: atomic.AddInt64(&proxy.sess, 1), @@ -151,16 +158,14 @@ func (proxy *ProxyHttpServer) handleRequest(out http.ResponseWriter, base *http. } if websocket.IsWebSocketUpgrade(base) { - proxy.handleWsRequest(ctx, out, base) + return proxy.handleWsRequest(ctx, writer, base) } else { - proxy.handleHttpRequest(ctx, out, base) + return proxy.handleHttpRequest(ctx, writer, base) } - - return nil } // TODO add handshake filter and message introspection -func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWriter, base *http.Request) { +func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, writer http.ResponseWriter, base *http.Request) (bool, error) { proto := websocket.Subprotocols(base) wg := &sync.WaitGroup{} @@ -180,23 +185,19 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWr ) if err != nil { - ctx.Logf(err.Error()) - if err == websocket.ErrBadHandshake { - writeResponse(ctx, resp, out) + writeResponse(ctx, resp, writer) } else { - http.Error(out, err.Error(), http.StatusBadGateway) + http.Error(writer, err.Error(), http.StatusBadGateway) } - return + return true, err } - client, err := proxy.WsServer.Upgrade(out, base, nil) + client, err := proxy.WsServer.Upgrade(writer, base, nil) if err != nil { - ctx.Logf(err.Error()) - - return + return true, err } wg.Add(2) @@ -209,10 +210,10 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, out http.ResponseWr remote.Close() client.Close() - return + return true, nil } -func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, out http.ResponseWriter, base *http.Request) { +func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, writer http.ResponseWriter, base *http.Request) (bool, error) { var ( req *http.Request resp *http.Response @@ -234,10 +235,10 @@ func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, out http.Response // TODO: add gateway timeout error in case of timeout switch err { default: - http.Error(out, err.Error(), http.StatusBadGateway) + http.Error(writer, err.Error(), http.StatusBadGateway) } - return + return false, err } body := resp.Body @@ -258,7 +259,9 @@ func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, out http.Response ctx.Logf("Received response: %v", resp.Status) ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) - writeResponse(ctx, resp, out) + writeResponse(ctx, resp, writer) + + return false, err } // New proxy server, logs to StdErr by default @@ -274,6 +277,7 @@ func NewProxyHttpServer() *ProxyHttpServer { Tr: &http.Transport{TLSClientConfig: tlsClientSkipVerify, Proxy: http.ProxyFromEnvironment}, + WsDialer: &websocket.Dialer{}, WsServer: &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true diff --git a/util.go b/util.go index ab3680bd..85e5cd61 100644 --- a/util.go +++ b/util.go @@ -1,55 +1,51 @@ package goproxy import ( - "net/http" - "net" "bufio" "errors" + "io" + "io/ioutil" + "net" + "net/http" ) -// This response writer can be hijacked multiple times -type hijackedResponseWriter struct { - // base writer - nested http.ResponseWriter - - // HiJacked fields - conn net.Conn - rw *bufio.ReadWriter - err error +type connResponseWriter struct { + dst io.Writer } -func (writer *hijackedResponseWriter) Header() http.Header { - return writer.nested.Header() +func (w *connResponseWriter) Header() http.Header { + panic("proxy: ConnectionResponseWriter does not implement Header()") } -func (writer *hijackedResponseWriter) Write(data []byte) (int, error) { - return writer.nested.Write(data) +func (w *connResponseWriter) Write(data []byte) (int, error) { + return w.dst.Write(data) } -func (writer *hijackedResponseWriter) WriteHeader(code int) { - writer.nested.WriteHeader(code) +func (w *connResponseWriter) WriteHeader(code int) { + panic("proxy: ConnectionResponseWriter does not implement WriteHeader(int)") } -func (writer *hijackedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := writer.nested.(http.Hijacker) +func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn, ok := w.dst.(net.Conn) if !ok { - return nil, nil, errors.New("proxy: nested http.ResponseWriter does not implement http.Hijacker interface") + return nil, nil, errors.New("proxy: nested io.Writer does not implement net.Conn interface") } - if !writer.hijacked() { - writer.conn, writer.rw, writer.err = hijacker.Hijack() - } + rw := bufio.NewReadWriter( + bufio.NewReader(io.MultiReader()), + bufio.NewWriter(ioutil.Discard), + ) - return writer.conn, writer.rw, writer.err + return conn, rw, nil } -func (writer *hijackedResponseWriter) hijacked() bool { - return writer.conn != nil || writer.err != nil +func NewConnResponseWriter(dst io.Writer) *connResponseWriter { + return &connResponseWriter{ + dst: dst, + } } -func NewHijackedResponseWriter(nested http.ResponseWriter) *hijackedResponseWriter { - return &hijackedResponseWriter{ - nested: nested, - } +func Error(w http.ResponseWriter, error string, code int) { + } From 255655fca03fd3d508588f84f08bc8732d24590d Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Mon, 19 Sep 2016 17:24:51 +0300 Subject: [PATCH 6/8] fix content len bug when serializing to conenction --- https.go | 55 ++++++++++++++++++++++++++++++------------------------- proxy.go | 54 ++++++++++++++++++++++++++++++++---------------------- 2 files changed, 62 insertions(+), 47 deletions(-) diff --git a/https.go b/https.go index adf0548a..66516fb3 100644 --- a/https.go +++ b/https.go @@ -128,8 +128,9 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + client := bufio.NewReader(proxyClient) + for { - client := bufio.NewReader(proxyClient) req, err := http.ReadRequest(client) if err != nil { @@ -151,6 +152,8 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque } } + proxyClient.Close() + case ConnectMitm: proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is TLS, mitm proxying it") @@ -167,38 +170,40 @@ func (proxy *ProxyHttpServer) handleConnect(w http.ResponseWriter, r *http.Reque return } } - go func() { - //TODO: cache connections to the remote website - rawClientTls := tls.Server(proxyClient, tlsConfig) - if err := rawClientTls.Handshake(); err != nil { - ctx.Warnf("Cannot handshake client %v %v", r.Host, err) - return - } - defer rawClientTls.Close() - clientTlsReader := bufio.NewReader(rawClientTls) - for !isEof(clientTlsReader) { - req, err := http.ReadRequest(clientTlsReader) - if err != nil { - if err != io.EOF { - ctx.Warnf("cannot read request of MITM HTTPS client: %+#v", err) - } + //TODO: cache connections to the remote website + rawClientTls := tls.Server(proxyClient, tlsConfig) + if err := rawClientTls.Handshake(); err != nil { + ctx.Warnf("Cannot handshake client %v %v", r.Host, err) + return + } + defer rawClientTls.Close() + clientTlsReader := bufio.NewReader(rawClientTls) + + for { + req, err := http.ReadRequest(clientTlsReader) - break + if err != nil { + if err != io.EOF { + ctx.Warnf("cannot read request of MITM HTTPS client: %+#v", err) } - req.URL, err = url.Parse("https://" + req.Host + req.URL.String()) + break + } - if end, err := proxy.handleRequest(NewConnResponseWriter(rawClientTls), req); end { - if err != nil { - ctx.Warnf("Error during serving MITM HTTPS request: %+#v", err) - } + req.URL, err = url.Parse("https://" + req.Host + req.URL.String()) - break + if end, err := proxy.handleRequest(NewConnResponseWriter(rawClientTls), req); end { + if err != nil { + ctx.Warnf("Error during serving MITM HTTPS request: %+#v", err) } + + break } - ctx.Logf("Exiting on EOF") - }() + } + + ctx.Logf("Exiting on EOF") + proxyClient.Close() case ConnectProxyAuthHijack: proxyClient.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n")) diff --git a/proxy.go b/proxy.go index 5e0f9991..2485599d 100644 --- a/proxy.go +++ b/proxy.go @@ -1,7 +1,6 @@ package goproxy import ( - "bufio" "io" "log" "net" @@ -11,6 +10,8 @@ import ( "sync/atomic" "github.com/gorilla/websocket" "sync" + "io/ioutil" + "bytes" ) // The basic proxy type. Implements http.Handler. @@ -35,25 +36,6 @@ type ProxyHttpServer struct { var hasPort = regexp.MustCompile(`:\d+$`) -func copyHeaders(dst, src http.Header) { - for k, _ := range dst { - dst.Del(k) - } - for k, vs := range src { - for _, v := range vs { - dst.Add(k, v) - } - } -} - -func isEof(r *bufio.Reader) bool { - _, err := r.Peek(1) - if err == io.EOF { - return true - } - return false -} - func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req *http.Request, resp *http.Response) { req = r for _, h := range proxy.reqHandlers { @@ -96,12 +78,41 @@ func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { } func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) { - ctx.Logf("Copying response to client: %v [%d]", resp.Status, resp.StatusCode) + ctx.Logf("Copying response to client: %v (%d bytes)", resp.Status, resp.ContentLength) // Fancy ResponseWriter if w, ok := out.(*connResponseWriter); ok { + // net/http: Response.Write produces invalid responses in this case, + // hacking to fix that + if resp.ContentLength == -1 { + defer resp.Body.Close() + + peek, err := ioutil.ReadAll( + io.LimitReader(resp.Body, 4 * 1024), + ) + + body := bytes.NewReader(peek) + + if err != nil { + ctx.Warnf("Error copying response: %s", err.Error()) + } + + if len(peek) < 4 * 1024 { + resp.ContentLength = int64(body.Len()) + resp.Body = ioutil.NopCloser(body) + } else { + resp.TransferEncoding = append(resp.TransferEncoding, "chunked") + resp.Body = ioutil.NopCloser(io.MultiReader( + body, + resp.Body, + )) + } + } + if err := resp.Write(w); err != nil { ctx.Warnf("Error copying response: %s", err.Error()) + } else { + ctx.Logf("Copied response to client") } return @@ -257,7 +268,6 @@ func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, writer http.Respo } ctx.Logf("Received response: %v", resp.Status) - ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) writeResponse(ctx, resp, writer) From 660240dc19fbd8f55154bafa0539e00a4d2d350a Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Mon, 19 Sep 2016 17:50:43 +0300 Subject: [PATCH 7/8] refactored and moved header cleaning --- proxy.go | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/proxy.go b/proxy.go index 2485599d..583b7e77 100644 --- a/proxy.go +++ b/proxy.go @@ -57,24 +57,18 @@ func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *Proxy return } -func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { - r.RequestURI = "" // this must be reset when serving a request with the client - ctx.Logf("Sending request %v %v", r.Method, r.URL.String()) - // If no Accept-Encoding header exists, Transport will add the headers it can accept - // and would wrap the response body with the relevant reader. - r.Header.Del("Accept-Encoding") - // curl can add that, see - // https://jdebp.eu./FGA/web-proxy-connection-header.html - r.Header.Del("Proxy-Connection") - r.Header.Del("Proxy-Authenticate") - r.Header.Del("Proxy-Authorization") - // Connection, Authenticate and Authorization are single hop Header: - // http://www.w3.org/Protocols/rfc2616/rfc2616.txt - // 14.10 Connection - // The Connection general-header field allows the sender to specify - // options that are desired for that particular connection and MUST NOT - // be communicated by proxies over further connections. - r.Header.Del("Connection") +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", } func writeResponse(ctx *ProxyCtx, resp *http.Response, out http.ResponseWriter) { @@ -168,6 +162,9 @@ func (proxy *ProxyHttpServer) handleRequest(writer http.ResponseWriter, base *ht proxy: proxy, } + // Clean-up + base.RequestURI = "" + if websocket.IsWebSocketUpgrade(base) { return proxy.handleWsRequest(ctx, writer, base) } else { @@ -236,8 +233,22 @@ func (proxy *ProxyHttpServer) handleHttpRequest(ctx *ProxyCtx, writer http.Respo req, resp = proxy.filterRequest(base, ctx) if resp == nil { - removeProxyHeaders(ctx, req) + // If no Accept-Encoding header exists, Transport will add the headers it can accept + // and would wrap the response body with the relevant reader. + req.Header.Del("Accept-Encoding") + + // Clean-up request + for _, h := range hopHeaders { + req.Header.Del(h) + } + + // Process resp, err = ctx.RoundTrip(req) + + // Clean-up response + for _, h := range hopHeaders { + resp.Header.Del(h) + } } if err != nil { From 6a73368a7b7a87c2610cae9cc7b7ee15c7faac5c Mon Sep 17 00:00:00 2001 From: Pavel Eremeev Date: Mon, 19 Sep 2016 18:08:42 +0300 Subject: [PATCH 8/8] updated error function --- proxy.go | 5 ++++- util.go | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/proxy.go b/proxy.go index 583b7e77..ff66e667 100644 --- a/proxy.go +++ b/proxy.go @@ -196,7 +196,7 @@ func (proxy *ProxyHttpServer) handleWsRequest(ctx *ProxyCtx, writer http.Respons if err == websocket.ErrBadHandshake { writeResponse(ctx, resp, writer) } else { - http.Error(writer, err.Error(), http.StatusBadGateway) + Error(writer, err.Error(), http.StatusBadGateway) } return true, err @@ -303,6 +303,9 @@ func NewProxyHttpServer() *ProxyHttpServer { CheckOrigin: func(r *http.Request) bool { return true }, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + Error(w, reason.Error(), status) + }, }, } proxy.ConnectDial = dialerFromEnv(&proxy) diff --git a/util.go b/util.go index 85e5cd61..76b7fd92 100644 --- a/util.go +++ b/util.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "net/http" + "strings" ) type connResponseWriter struct { @@ -46,6 +47,19 @@ func NewConnResponseWriter(dst io.Writer) *connResponseWriter { } } -func Error(w http.ResponseWriter, error string, code int) { +func Error(out http.ResponseWriter, error string, code int) { + resp := &http.Response{ + StatusCode: code, + ContentLength: -1, + Body: ioutil.NopCloser(strings.NewReader(error)), + } + + ctx := &ProxyCtx{ + Req: nil, + Session: 0, + Websocket: false, + proxy: nil, + } + writeResponse(ctx, resp, out) }