diff --git a/middleware.go b/middleware.go index 162342a5..e5c4eef7 100644 --- a/middleware.go +++ b/middleware.go @@ -192,9 +192,13 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil) } } else { - // fix data race: must deep copy. - bodyBuf := bytes.NewBuffer(append([]byte{}, r.bodyBuf.Bytes()...)) - r.RawRequest, err = http.NewRequest(r.Method, r.URL, bodyBuf) + // deep copy + bodyCopy := acquireBuffer() + _, err := io.Copy(bodyCopy, bytes.NewReader(r.bodyBuf.Bytes())) + if err != nil { + return err + } + r.RawRequest, err = http.NewRequest(r.Method, r.URL, bodyCopy) } if err != nil { @@ -228,17 +232,18 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest = r.RawRequest.WithContext(r.ctx) } - bodyCopy, err := getBodyCopy(r) - if err != nil { - return err - } + if r.bodyBuf == nil { + bodyCopy, err := getRawRequestBodyCopy(r) + if err != nil { + return err + } - // assign get body func for the underlying raw request instance - r.RawRequest.GetBody = func() (io.ReadCloser, error) { if bodyCopy != nil { - return io.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil + // assign get body func for the underlying raw request instance + r.RawRequest.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil + } } - return nil, nil } return @@ -544,17 +549,7 @@ func saveResponseIntoFile(c *Client, res *Response) error { return nil } -func getBodyCopy(r *Request) (*bytes.Buffer, error) { - // If r.bodyBuf present, return the copy - if r.bodyBuf != nil { - bodyCopy := acquireBuffer() - if _, err := io.Copy(bodyCopy, bytes.NewReader(r.bodyBuf.Bytes())); err != nil { - // cannot use io.Copy(bodyCopy, r.bodyBuf) because io.Copy reset r.bodyBuf - return nil, err - } - return bodyCopy, nil - } - +func getRawRequestBodyCopy(r *Request) (*bytes.Buffer, error) { // Maybe body is `io.Reader`. // Note: Resty user have to watchout for large body size of `io.Reader` if r.RawRequest.Body != nil { diff --git a/middleware_test.go b/middleware_test.go index fef6f008..84a614b8 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -1,6 +1,7 @@ package resty import ( + "bytes" "net/url" "testing" ) @@ -227,3 +228,47 @@ func Test_parseRequestURL(t *testing.T) { }) } } + +func Test_createHTTPRequest(t *testing.T) { + type args struct { + c *Client + r *Request + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "bodyBuf is not nil, deep copy", + args: args{ + c: &Client{}, + r: func() *Request { + req := &Request{} + req.bodyBuf = bytes.NewBufferString("test") + return req + }(), + }, + wantErr: false, + }, + { + name: "bodyBuf is nil, deep copy", + args: args{ + c: &Client{}, + r: func() *Request { + req := &Request{} + req.bodyBuf = nil + return req + }(), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := createHTTPRequest(tt.args.c, tt.args.r); (err != nil) != tt.wantErr { + t.Errorf("createHTTPRequest() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}