Skip to content

Commit

Permalink
feat: add mutex to make Client thread-safe (#827)
Browse files Browse the repository at this point in the history
  • Loading branch information
TurtleRuss authored Sep 2, 2024
1 parent 0953f5a commit f8a7343
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 162 deletions.
457 changes: 393 additions & 64 deletions client.go

Large diffs are not rendered by default.

65 changes: 32 additions & 33 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func TestClientAuthScheme(t *testing.T) {

// Ensure setting the scheme works as well
c.SetAuthScheme("Bearer")
assertEqual(t, "Bearer", c.AuthScheme())

resp2, err2 := c.R().Get("/profile")
assertError(t, err2)
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestClientProxy(t *testing.T) {
assertNotNil(t, resp)
assertNotNil(t, err)

// Error
// error
c.SetProxy("//not.a.user@%66%6f%6f.com:8888")

resp, err = c.R().
Expand Down Expand Up @@ -339,9 +340,9 @@ func TestClientSetHeaderVerbatim(t *testing.T) {
SetHeader("header-lowercase", "value_standard")

//lint:ignore SA1008 valid one, so ignore this!
unConventionHdrValue := strings.Join(c.Header["header-lowercase"], "")
unConventionHdrValue := strings.Join(c.Header()["header-lowercase"], "")
assertEqual(t, "value_lowercase", unConventionHdrValue)
assertEqual(t, "value_standard", c.Header.Get("Header-Lowercase"))
assertEqual(t, "value_standard", c.Header().Get("Header-Lowercase"))
}

func TestClientSetTransport(t *testing.T) {
Expand Down Expand Up @@ -387,20 +388,20 @@ func TestClientOptions(t *testing.T) {
assertEqual(t, client.setContentLength, true)

client.SetBaseURL("http://httpbin.org")
assertEqual(t, "http://httpbin.org", client.BaseURL)
assertEqual(t, "http://httpbin.org", client.BaseURL())

client.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8")
client.SetHeaders(map[string]string{
hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.1",
"X-Request-Id": strconv.FormatInt(time.Now().UnixNano(), 10),
})
assertEqual(t, "application/json; charset=utf-8", client.Header.Get(hdrContentTypeKey))
assertEqual(t, "application/json; charset=utf-8", client.Header().Get(hdrContentTypeKey))

client.SetCookie(&http.Cookie{
Name: "default-cookie",
Value: "This is cookie default-cookie value",
})
assertEqual(t, "default-cookie", client.Cookies[0].Name)
assertEqual(t, "default-cookie", client.Cookies()[0].Name)

cookies := []*http.Cookie{
{
Expand All @@ -412,45 +413,45 @@ func TestClientOptions(t *testing.T) {
},
}
client.SetCookies(cookies)
assertEqual(t, "default-cookie-1", client.Cookies[1].Name)
assertEqual(t, "default-cookie-2", client.Cookies[2].Name)
assertEqual(t, "default-cookie-1", client.Cookies()[1].Name)
assertEqual(t, "default-cookie-2", client.Cookies()[2].Name)

client.SetQueryParam("test_param_1", "Param_1")
client.SetQueryParams(map[string]string{"test_param_2": "Param_2", "test_param_3": "Param_3"})
assertEqual(t, "Param_3", client.QueryParam.Get("test_param_3"))
assertEqual(t, "Param_3", client.QueryParam().Get("test_param_3"))

rTime := strconv.FormatInt(time.Now().UnixNano(), 10)
client.SetFormData(map[string]string{"r_time": rTime})
assertEqual(t, rTime, client.FormData.Get("r_time"))
assertEqual(t, rTime, client.FormData().Get("r_time"))

client.SetBasicAuth("myuser", "mypass")
assertEqual(t, "myuser", client.UserInfo.Username)
assertEqual(t, "myuser", client.BasicAuth().Username)

client.SetAuthToken("AC75BD37F019E08FBC594900518B4F7E")
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token)
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token())

client.SetDisableWarn(true)
assertEqual(t, client.DisableWarn, true)
assertEqual(t, client.DisableWarn(), true)

client.SetRetryCount(3)
assertEqual(t, 3, client.RetryCount)
assertEqual(t, 3, client.RetryCount())

rwt := time.Duration(1000) * time.Millisecond
client.SetRetryWaitTime(rwt)
assertEqual(t, rwt, client.RetryWaitTime)
assertEqual(t, rwt, client.RetryWaitTime())

mrwt := time.Duration(2) * time.Second
client.SetRetryMaxWaitTime(mrwt)
assertEqual(t, mrwt, client.RetryMaxWaitTime)
assertEqual(t, mrwt, client.RetryMaxWaitTime())

client.AddRetryAfterErrorCondition()
equal(client.RetryConditions[0], func(response *Response, err error) bool {
equal(client.RetryConditions()[0], func(response *Response, err error) bool {
return response.IsError()
})

err := &AuthError{}
client.SetError(err)
if reflect.TypeOf(err) == client.Error {
if reflect.TypeOf(err) == client.Error() {
t.Error("SetError failed")
}

Expand All @@ -476,14 +477,14 @@ func TestClientOptions(t *testing.T) {
client.SetContentLength(true)

client.SetDebug(true)
assertEqual(t, client.Debug, true)
assertEqual(t, client.Debug(), true)

var sl int64 = 1000000
client.SetDebugBodyLimit(sl)
assertEqual(t, client.debugBodySizeLimit, sl)

client.SetAllowGetMethodPayload(true)
assertEqual(t, client.AllowGetMethodPayload, true)
assertEqual(t, client.AllowGetMethodPayload(), true)

client.SetScheme("http")
assertEqual(t, client.scheme, "http")
Expand Down Expand Up @@ -617,31 +618,31 @@ func TestClientNewRequest(t *testing.T) {
func TestClientSetJSONMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetJSONMarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONMarshal)
p1 := fmt.Sprintf("%p", c.JSONMarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetJSONUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetJSONUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONUnmarshal)
p1 := fmt.Sprintf("%p", c.JSONUnmarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetXMLMarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLMarshal)
p1 := fmt.Sprintf("%p", c.XMLMarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetXMLUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLUnmarshal)
p1 := fmt.Sprintf("%p", c.XMLUnmarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}
Expand Down Expand Up @@ -1145,23 +1146,21 @@ func TestClone(t *testing.T) {
parent.SetBaseURL("http://localhost")

// set an interface field
parent.UserInfo = &User{
Username: "parent",
}
parent.SetBasicAuth("parent", "")

clone := parent.Clone()
// update value of non-interface type - change will only happen on clone
clone.SetBaseURL("https://local.host")
// update value of interface type - change will also happen on parent
clone.UserInfo.Username = "clone"
clone.BasicAuth().Username = "clone"

// asert non-interface type
assertEqual(t, "http://localhost", parent.BaseURL)
assertEqual(t, "https://local.host", clone.BaseURL)
assertEqual(t, "http://localhost", parent.BaseURL())
assertEqual(t, "https://local.host", clone.BaseURL())

// assert interface type
assertEqual(t, "clone", parent.UserInfo.Username)
assertEqual(t, "clone", clone.UserInfo.Username)
assertEqual(t, "clone", parent.BasicAuth().Username)
assertEqual(t, "clone", clone.BasicAuth().Username)
}

func TestResponseBodyLimit(t *testing.T) {
Expand All @@ -1172,7 +1171,7 @@ func TestResponseBodyLimit(t *testing.T) {

t.Run("Client body limit", func(t *testing.T) {
c := dc().SetResponseBodyLimit(1024)

assertEqual(t, 1024, c.ResponseBodyLimit())
_, err := c.R().Get(ts.URL + "/")
assertNotNil(t, err)
assertEqual(t, err, ErrResponseBodyTooLarge)
Expand Down
54 changes: 27 additions & 27 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ const debugRequestLogKey = "__restyDebugRequestLog"
//_______________________________________________________________________

func parseRequestURL(c *Client, r *Request) error {
if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
if l := len(c.pathParams) + len(c.rawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
params := make(map[string]string, l)

// GitHub #103 Path Params
for p, v := range r.PathParams {
params[p] = url.PathEscape(v)
}
for p, v := range c.PathParams {
for p, v := range c.pathParams {
if _, ok := params[p]; !ok {
params[p] = url.PathEscape(v)
}
Expand All @@ -46,7 +46,7 @@ func parseRequestURL(c *Client, r *Request) error {
params[p] = v
}
}
for p, v := range c.RawPathParams {
for p, v := range c.rawPathParams {
if _, ok := params[p]; !ok {
params[p] = v
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func parseRequestURL(c *Client, r *Request) error {
r.URL = "/" + r.URL
}

reqURL, err = url.Parse(c.BaseURL + r.URL)
reqURL, err = url.Parse(c.baseURL + r.URL)
if err != nil {
return err
}
Expand All @@ -126,8 +126,8 @@ func parseRequestURL(c *Client, r *Request) error {
}

// Adding Query Param
if len(c.QueryParam)+len(r.QueryParam) > 0 {
for k, v := range c.QueryParam {
if len(c.queryParam)+len(r.QueryParam) > 0 {
for k, v := range c.queryParam {
// skip query parameter if it was set in request
if _, ok := r.QueryParam[k]; ok {
continue
Expand Down Expand Up @@ -155,7 +155,7 @@ func parseRequestURL(c *Client, r *Request) error {
}

func parseRequestHeader(c *Client, r *Request) error {
for k, v := range c.Header {
for k, v := range c.header {
if _, ok := r.Header[k]; ok {
continue
}
Expand All @@ -174,13 +174,13 @@ func parseRequestHeader(c *Client, r *Request) error {
}

func parseRequestBody(c *Client, r *Request) error {
if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if isPayloadSupported(r.Method, c.allowGetMethodPayload) {
switch {
case r.isMultiPart: // Handling Multipart
if err := handleMultipart(c, r); err != nil {
return err
}
case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data
case len(c.formData) > 0 || len(r.FormData) > 0: // Handling Form Data
handleFormData(c, r)
case r.Body != nil: // Handling Request body
handleContentType(c, r)
Expand All @@ -205,7 +205,7 @@ func parseRequestBody(c *Client, r *Request) error {

func createHTTPRequest(c *Client, r *Request) (err error) {
if r.bodyBuf == nil {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.allowGetMethodPayload) {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
} else if c.setContentLength || r.setContentLength {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
Expand All @@ -229,7 +229,7 @@ func createHTTPRequest(c *Client, r *Request) (err error) {
r.RawRequest.Header = r.Header

// Add cookies from client instance into http request
for _, cookie := range c.Cookies {
for _, cookie := range c.cookies {
r.RawRequest.AddCookie(cookie)
}

Expand Down Expand Up @@ -271,32 +271,32 @@ func addCredentials(c *Client, r *Request) error {
if r.UserInfo != nil { // takes precedence
r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
isBasicAuth = true
} else if c.UserInfo != nil {
r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
} else if c.userInfo != nil {
r.RawRequest.SetBasicAuth(c.userInfo.Username, c.userInfo.Password)
isBasicAuth = true
}

if !c.DisableWarn {
if !c.disableWarn {
if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
}
}

// Set the Authorization Header Scheme
// Set the Authorization header Scheme
var authScheme string
if !IsStringEmpty(r.AuthScheme) {
authScheme = r.AuthScheme
} else if !IsStringEmpty(c.AuthScheme) {
authScheme = c.AuthScheme
} else if !IsStringEmpty(c.authScheme) {
authScheme = c.authScheme
} else {
authScheme = "Bearer"
}

// Build the Token Auth header
// Build the token Auth header
if !IsStringEmpty(r.Token) { // takes precedence
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.Token) {
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+c.Token)
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.token) {
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+c.token)
}

return nil
Expand Down Expand Up @@ -401,11 +401,11 @@ func parseResponseBody(c *Client, res *Response) (err error) {
}
}

// HTTP status code > 399, considered as Error
// HTTP status code > 399, considered as error
if res.IsError() {
// global error interface
if res.Request.Error == nil && c.Error != nil {
res.Request.Error = reflect.New(c.Error).Interface()
if res.Request.Error == nil && c.error != nil {
res.Request.Error = reflect.New(c.error).Interface()
}

if res.Request.Error != nil {
Expand All @@ -431,7 +431,7 @@ func handleMultipart(c *Client, r *Request) error {
}
}

for k, v := range c.FormData {
for k, v := range c.formData {
for _, iv := range v {
if err := w.WriteField(k, iv); err != nil {
return err
Expand Down Expand Up @@ -472,7 +472,7 @@ func handleMultipart(c *Client, r *Request) error {
}

func handleFormData(c *Client, r *Request) {
for k, v := range c.FormData {
for k, v := range c.formData {
if _, ok := r.FormData[k]; ok {
continue
}
Expand Down Expand Up @@ -520,7 +520,7 @@ func handleRequestBody(c *Client, r *Request) error {
if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
r.bodyBuf, err = jsonMarshal(c, r, r.Body)
} else if IsXMLType(contentType) && (kind == reflect.Struct) {
bodyBytes, err = c.XMLMarshal(r.Body)
bodyBytes, err = c.xmlMarshal(r.Body)
}
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Test_parseRequestURL(t *testing.T) {
r.SetPathParams(map[string]string{
"foo": "4/5",
}).SetRawPathParams(map[string]string{
"foo": "4/5", // ignored, because the PathParams takes precedence over the RawPathParams
"foo": "4/5", // ignored, because the pathParams takes precedence over the rawPathParams
"bar": "6/7",
})
r.URL = "https://example.com/{foo}/{bar}"
Expand Down Expand Up @@ -182,7 +182,7 @@ func Test_parseRequestURL(t *testing.T) {
{
name: "using deprecated HostURL with relative path in request URL",
init: func(c *Client, r *Request) {
c.BaseURL = "https://example.com"
c.SetBaseURL("https://example.com")
r.URL = "foo/bar"
},
expectedURL: "https://example.com/foo/bar",
Expand Down
Loading

0 comments on commit f8a7343

Please sign in to comment.