From 732c199519d2eecafea40f2fbdb3d0b525593806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=9F=E5=A4=A7=E9=99=86?= <10188142zhong_dalu@cn.tre-inc.com> Date: Mon, 25 Sep 2023 14:56:33 +0800 Subject: [PATCH] improve & add unit test --- middleware.go | 29 +++++++++++++---------------- middleware_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/middleware.go b/middleware.go index e5c4eef7..9bdac5f2 100644 --- a/middleware.go +++ b/middleware.go @@ -21,9 +21,9 @@ import ( const debugRequestLogKey = "__restyDebugRequestLog" -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Request Middleware(s) -//_______________________________________________________________________ +// _______________________________________________________________________ func parseRequestURL(c *Client, r *Request) error { // GitHub #103 Path Params @@ -183,6 +183,7 @@ CL: } func createHTTPRequest(c *Client, r *Request) (err error) { + var bodyCopy = acquireBuffer() if r.bodyBuf == nil { if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) { r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader) @@ -191,9 +192,12 @@ func createHTTPRequest(c *Client, r *Request) (err error) { } else { r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil) } + bodyCopy, err = getRawRequestBodyCopy(r) + if err != nil { + return err + } } else { // deep copy - bodyCopy := acquireBuffer() _, err := io.Copy(bodyCopy, bytes.NewReader(r.bodyBuf.Bytes())) if err != nil { return err @@ -232,17 +236,10 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest = r.RawRequest.WithContext(r.ctx) } - if r.bodyBuf == nil { - bodyCopy, err := getRawRequestBodyCopy(r) - if err != nil { - return err - } - - if bodyCopy != nil { - // assign get body func for the underlying raw request instance - r.RawRequest.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil - } + if bodyCopy != nil && bodyCopy.Len() > 0 { + // assign get body func for the underlying raw request instance + r.RawRequest.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil } } @@ -312,9 +309,9 @@ func requestLogger(c *Client, r *Request) error { return nil } -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Response Middleware(s) -//_______________________________________________________________________ +// _______________________________________________________________________ func responseLogger(c *Client, res *Response) error { if res.Request.Debug { diff --git a/middleware_test.go b/middleware_test.go index fef6f008..84a614b8 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -1,6 +1,7 @@ package resty import ( + "bytes" "net/url" "testing" ) @@ -227,3 +228,47 @@ func Test_parseRequestURL(t *testing.T) { }) } } + +func Test_createHTTPRequest(t *testing.T) { + type args struct { + c *Client + r *Request + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "bodyBuf is not nil, deep copy", + args: args{ + c: &Client{}, + r: func() *Request { + req := &Request{} + req.bodyBuf = bytes.NewBufferString("test") + return req + }(), + }, + wantErr: false, + }, + { + name: "bodyBuf is nil, deep copy", + args: args{ + c: &Client{}, + r: func() *Request { + req := &Request{} + req.bodyBuf = nil + return req + }(), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := createHTTPRequest(tt.args.c, tt.args.r); (err != nil) != tt.wantErr { + t.Errorf("createHTTPRequest() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}