From 5860369bee16f3364b47651cb2b35d0da21fd46a Mon Sep 17 00:00:00 2001 From: TurtleRuss Date: Mon, 19 Aug 2024 17:48:07 +0800 Subject: [PATCH 1/4] feat: add mutex to make Client thread-safe --- client.go | 455 +++++++++++++++++++++++++++++++++++++++------ client_test.go | 62 +++--- middleware.go | 54 +++--- middleware_test.go | 4 +- request.go | 16 +- resty.go | 27 +-- retry.go | 2 +- retry_test.go | 2 +- util.go | 6 +- 9 files changed, 481 insertions(+), 147 deletions(-) diff --git a/client.go b/client.go index 20292a7b..4900f3eb 100644 --- a/client.go +++ b/client.go @@ -116,35 +116,35 @@ type ClientTimeoutSetting struct { // Resty also provides an options to override most of the client settings // at request level. type Client struct { - BaseURL string - QueryParam url.Values - FormData url.Values - PathParams map[string]string - RawPathParams map[string]string - Header http.Header - UserInfo *User - Token string - AuthScheme string - Cookies []*http.Cookie - Error reflect.Type - Debug bool - DisableWarn bool - AllowGetMethodPayload bool - RetryCount int - RetryWaitTime time.Duration - RetryMaxWaitTime time.Duration - RetryConditions []RetryConditionFunc - RetryHooks []OnRetryFunc - RetryAfter RetryAfterFunc - RetryResetReaders bool - JSONMarshal func(v interface{}) ([]byte, error) - JSONUnmarshal func(data []byte, v interface{}) error - XMLMarshal func(v interface{}) ([]byte, error) - XMLUnmarshal func(data []byte, v interface{}) error - - // HeaderAuthorizationKey is used to set/access Request Authorization header + baseURL string + queryParam url.Values + formData url.Values + pathParams map[string]string + rawPathParams map[string]string + header http.Header + userInfo *User + token string + authScheme string + cookies []*http.Cookie + error reflect.Type + debug bool + disableWarn bool + allowGetMethodPayload bool + retryCount int + retryWaitTime time.Duration + retryMaxWaitTime time.Duration + retryConditions []RetryConditionFunc + retryHooks []OnRetryFunc + retryAfter RetryAfterFunc + retryResetReaders bool + jsonMarshal func(v interface{}) ([]byte, error) + jsonUnmarshal func(data []byte, v interface{}) error + xmlMarshal func(v interface{}) ([]byte, error) + xmlUnmarshal func(data []byte, v interface{}) error + + // headerAuthorizationKey is used to set/access Request Authorization header // value when `SetAuthToken` option is used. - HeaderAuthorizationKey string + headerAuthorizationKey string jsonEscapeHTML bool setContentLength bool @@ -170,6 +170,7 @@ type Client struct { invalidHooks []ErrorHook panicHooks []ErrorHook rateLimiter RateLimiter + lock *sync.RWMutex } // User type is to hold an username and password information @@ -181,6 +182,13 @@ type User struct { // Client methods //___________________________________ +// BaseURL method is to get Base URL in the client instance. +func (c *Client) BaseURL() string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.baseURL +} + // SetBaseURL method is to set Base URL in the client instance. It will be used with request // raised from this client with relative URL // @@ -192,10 +200,19 @@ type User struct { // // Since v2.7.0 func (c *Client) SetBaseURL(url string) *Client { - c.BaseURL = strings.TrimRight(url, "/") + c.lock.Lock() + defer c.lock.Unlock() + c.baseURL = strings.TrimRight(url, "/") return c } +// Header method gets all header fields and its value in the client instance. +func (c *Client) Header() http.Header { + c.lock.RLock() + defer c.lock.RUnlock() + return c.header +} + // SetHeader method sets a single header field and its value in the client instance. // These headers will be applied to all requests raised from this client instance. // Also it can be overridden at request level header options. @@ -208,7 +225,9 @@ func (c *Client) SetBaseURL(url string) *Client { // SetHeader("Content-Type", "application/json"). // SetHeader("Accept", "application/json") func (c *Client) SetHeader(header, value string) *Client { - c.Header.Set(header, value) + c.lock.Lock() + defer c.lock.Unlock() + c.header.Set(header, value) return c } @@ -225,8 +244,10 @@ func (c *Client) SetHeader(header, value string) *Client { // "Accept": "application/json", // }) func (c *Client) SetHeaders(headers map[string]string) *Client { + c.lock.Lock() + defer c.lock.Unlock() for h, v := range headers { - c.Header.Set(h, v) + c.header.Set(h, v) } return c } @@ -243,7 +264,23 @@ func (c *Client) SetHeaders(headers map[string]string) *Client { // // Since v2.6.0 func (c *Client) SetHeaderVerbatim(header, value string) *Client { - c.Header[header] = []string{value} + c.lock.Lock() + defer c.lock.Unlock() + c.header[header] = []string{value} + return c +} + +// UserInfo method gets the user information in the client instance. +func (c *Client) UserInfo() *User { + c.lock.RLock() + defer c.lock.RUnlock() + return c.userInfo +} + +func (c *Client) SetUserInfo(user *User) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.userInfo = user return c } @@ -254,10 +291,19 @@ func (c *Client) SetHeaderVerbatim(header, value string) *Client { // // client.SetCookieJar(nil) func (c *Client) SetCookieJar(jar http.CookieJar) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.httpClient.Jar = jar return c } +// Cookies method gets all cookies in the client instance. +func (c *Client) Cookies() []*http.Cookie { + c.lock.RLock() + defer c.lock.RUnlock() + return c.cookies +} + // SetCookie method appends a single cookie in the client instance. // These cookies will be added to all the request raised from this client instance. // @@ -266,7 +312,9 @@ func (c *Client) SetCookieJar(jar http.CookieJar) *Client { // Value:"This is cookie value", // }) func (c *Client) SetCookie(hc *http.Cookie) *Client { - c.Cookies = append(c.Cookies, hc) + c.lock.Lock() + defer c.lock.Unlock() + c.cookies = append(c.cookies, hc) return c } @@ -287,10 +335,19 @@ func (c *Client) SetCookie(hc *http.Cookie) *Client { // // Setting a cookies into resty // client.SetCookies(cookies) func (c *Client) SetCookies(cs []*http.Cookie) *Client { - c.Cookies = append(c.Cookies, cs...) + c.lock.Lock() + defer c.lock.Unlock() + c.cookies = append(c.cookies, cs...) return c } +// QueryParam method gets all parameters and their values in the client instance. +func (c *Client) QueryParam() url.Values { + c.lock.RLock() + defer c.lock.RUnlock() + return c.queryParam +} + // SetQueryParam method sets single parameter and its value in the client instance. // It will be formed as query string for the request. // @@ -304,7 +361,9 @@ func (c *Client) SetCookies(cs []*http.Cookie) *Client { // SetQueryParam("search", "kitchen papers"). // SetQueryParam("size", "large") func (c *Client) SetQueryParam(param, value string) *Client { - c.QueryParam.Set(param, value) + c.lock.Lock() + defer c.lock.Unlock() + c.queryParam.Set(param, value) return c } @@ -322,12 +381,20 @@ func (c *Client) SetQueryParam(param, value string) *Client { // "size": "large", // }) func (c *Client) SetQueryParams(params map[string]string) *Client { + // Do not lock here since there is potential deadlock. for p, v := range params { c.SetQueryParam(p, v) } return c } +// FormData method gets form parameters and their values in the client instance. +func (c *Client) FormData() url.Values { + c.lock.RLock() + defer c.lock.RUnlock() + return c.formData +} + // SetFormData method sets Form parameters and their values in the client instance. // It's applicable only HTTP method `POST` and `PUT` and request content type would be set as // `application/x-www-form-urlencoded`. These form data will be added to all the request raised from @@ -340,12 +407,21 @@ func (c *Client) SetQueryParams(params map[string]string) *Client { // "user_id": "3455454545", // }) func (c *Client) SetFormData(data map[string]string) *Client { + c.lock.Lock() + defer c.lock.Unlock() for k, v := range data { - c.FormData.Set(k, v) + c.formData.Set(k, v) } return c } +// BasicAuth method gets the basic authentication header in the HTTP request. +func (c *Client) BasicAuth() *User { + c.lock.RLock() + defer c.lock.RUnlock() + return c.userInfo +} + // SetBasicAuth method sets the basic authentication header in the HTTP request. For Example: // // Authorization: Basic @@ -359,10 +435,26 @@ func (c *Client) SetFormData(data map[string]string) *Client { // // See `Request.SetBasicAuth`. func (c *Client) SetBasicAuth(username, password string) *Client { - c.UserInfo = &User{Username: username, Password: password} + c.lock.Lock() + defer c.lock.Unlock() + c.userInfo = &User{Username: username, Password: password} return c } +// Token method gets the auth token of the `Authorization` header for all HTTP requests. +func (c *Client) Token() string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.token +} + +// HeaderAuthorizationKey method gets the Header Authorization Key on the Resty client. +func (c *Client) HeaderAuthorizationKey() string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.headerAuthorizationKey +} + // SetAuthToken method sets the auth token of the `Authorization` header for all HTTP requests. // The default auth scheme is `Bearer`, it can be customized with the method `SetAuthScheme`. For Example: // @@ -377,10 +469,19 @@ func (c *Client) SetBasicAuth(username, password string) *Client { // // See `Request.SetAuthToken`. func (c *Client) SetAuthToken(token string) *Client { - c.Token = token + c.lock.Lock() + defer c.lock.Unlock() + c.token = token return c } +// AuthScheme method gets the auth scheme type in the HTTP request. +func (c *Client) AuthScheme() string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.authScheme +} + // SetAuthScheme method sets the auth scheme type in the HTTP request. For Example: // // Authorization: @@ -400,12 +501,14 @@ func (c *Client) SetAuthToken(token string) *Client { // // See `Request.SetAuthToken`. func (c *Client) SetAuthScheme(scheme string) *Client { - c.AuthScheme = scheme + c.lock.Lock() + defer c.lock.Unlock() + c.authScheme = scheme return c } // SetDigestAuth method sets the Digest Access auth scheme for the client. If a server responds with 401 and sends -// a Digest challenge in the WWW-Authenticate Header, requests will be resent with the appropriate Authorization Header. +// a Digest challenge in the WWW-Authenticate header, requests will be resent with the appropriate Authorization header. // // For Example: To set the Digest scheme with user "Mufasa" and password "Circle Of Life" // @@ -417,7 +520,9 @@ func (c *Client) SetAuthScheme(scheme string) *Client { // // See `Request.SetDigestAuth`. func (c *Client) SetDigestAuth(username, password string) *Client { + c.lock.Lock() oldTransport := c.httpClient.Transport + c.lock.Unlock() c.OnBeforeRequest(func(c *Client, _ *Request) error { c.httpClient.Transport = &digestTransport{ digestCredentials: digestCredentials{username, password}, @@ -441,7 +546,7 @@ func (c *Client) R() *Request { Cookies: make([]*http.Cookie, 0), PathParams: map[string]string{}, RawPathParams: map[string]string{}, - Debug: c.Debug, + Debug: c.debug, client: c, multipartFiles: []*File{}, @@ -511,6 +616,8 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { // Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one // set will be invoked for each call to Request.Execute() that completes. func (c *Client) OnError(h ErrorHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.errorHooks = append(c.errorHooks, h) return c } @@ -523,6 +630,8 @@ func (c *Client) OnError(h ErrorHook) *Client { // // Since v2.8.0 func (c *Client) OnSuccess(h SuccessHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.successHooks = append(c.successHooks, h) return c } @@ -535,6 +644,8 @@ func (c *Client) OnSuccess(h SuccessHook) *Client { // // Since v2.8.0 func (c *Client) OnInvalid(h ErrorHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.invalidHooks = append(c.invalidHooks, h) return c } @@ -549,6 +660,8 @@ func (c *Client) OnInvalid(h ErrorHook) *Client { // // Since v2.8.0 func (c *Client) OnPanic(h ErrorHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.panicHooks = append(c.panicHooks, h) return c } @@ -558,6 +671,8 @@ func (c *Client) OnPanic(h ErrorHook) *Client { // // Note: Only one pre-request hook can be registered. Use `client.OnBeforeRequest` for multiple. func (c *Client) SetPreRequestHook(h PreRequestHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() if c.preReqHook != nil { c.log.Warnf("Overwriting an existing pre-request hook: %s", functionName(h)) } @@ -565,6 +680,13 @@ func (c *Client) SetPreRequestHook(h PreRequestHook) *Client { return c } +// Debug method gets if the Resty client is in debug mode. +func (c *Client) Debug() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.debug +} + // SetDebug method enables the debug mode on Resty client. Client logs details of every request and response. // For `Request` it logs information such as HTTP verb, Relative URL path, Host, Headers, Body if it has one. // For `Response` it logs information such as Status, Response Time, Headers, Body if it has one. @@ -573,7 +695,9 @@ func (c *Client) SetPreRequestHook(h PreRequestHook) *Client { // // Also it can be enabled at request level for particular request, see `Request.SetDebug`. func (c *Client) SetDebug(d bool) *Client { - c.Debug = d + c.lock.Lock() + defer c.lock.Unlock() + c.debug = d return c } @@ -581,6 +705,8 @@ func (c *Client) SetDebug(d bool) *Client { // // client.SetDebugBodyLimit(1000000) func (c *Client) SetDebugBodyLimit(sl int64) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.debugBodySizeLimit = sl return c } @@ -588,6 +714,8 @@ func (c *Client) SetDebugBodyLimit(sl int64) *Client { // OnRequestLog method used to set request log callback into Resty. Registered callback gets // called before the resty actually logs the information. func (c *Client) OnRequestLog(rl RequestLogCallback) *Client { + c.lock.Lock() + defer c.lock.Unlock() if c.requestLog != nil { c.log.Warnf("Overwriting an existing on-request-log callback from=%s to=%s", functionName(c.requestLog), functionName(rl)) @@ -599,6 +727,8 @@ func (c *Client) OnRequestLog(rl RequestLogCallback) *Client { // OnResponseLog method used to set response log callback into Resty. Registered callback gets // called before the resty actually logs the information. func (c *Client) OnResponseLog(rl ResponseLogCallback) *Client { + c.lock.Lock() + defer c.lock.Unlock() if c.responseLog != nil { c.log.Warnf("Overwriting an existing on-response-log callback from=%s to=%s", functionName(c.responseLog), functionName(rl)) @@ -607,23 +737,41 @@ func (c *Client) OnResponseLog(rl ResponseLogCallback) *Client { return c } +// DisableWarn method gets if the Resty client disables the warning message. +func (c *Client) DisableWarn() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.disableWarn +} + // SetDisableWarn method disables the warning message on Resty client. // // For Example: Resty warns the user when BasicAuth used on non-TLS mode. // // client.SetDisableWarn(true) func (c *Client) SetDisableWarn(d bool) *Client { - c.DisableWarn = d + c.lock.Lock() + defer c.lock.Unlock() + c.disableWarn = d return c } +// AllowGetMethodPayload method gets if the Resty client allows the GET method with payload. +func (c *Client) AllowGetMethodPayload() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.allowGetMethodPayload +} + // SetAllowGetMethodPayload method allows the GET method with payload on Resty client. // // For Example: Resty allows the user sends request with a payload on HTTP GET method. // // client.SetAllowGetMethodPayload(true) func (c *Client) SetAllowGetMethodPayload(a bool) *Client { - c.AllowGetMethodPayload = a + c.lock.Lock() + defer c.lock.Unlock() + c.allowGetMethodPayload = a return c } @@ -631,6 +779,8 @@ func (c *Client) SetAllowGetMethodPayload(a bool) *Client { // // Compliant to interface `resty.Logger`. func (c *Client) SetLogger(l Logger) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.log = l return c } @@ -642,6 +792,8 @@ func (c *Client) SetLogger(l Logger) *Client { // // Also you have an option to enable for particular request. See `Request.SetContentLength` func (c *Client) SetContentLength(l bool) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.setContentLength = l return c } @@ -650,11 +802,20 @@ func (c *Client) SetContentLength(l bool) *Client { // // client.SetTimeout(time.Duration(1 * time.Minute)) func (c *Client) SetTimeout(timeout time.Duration) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.httpClient.Timeout = timeout return c } -// SetError method is to register the global or client common `Error` object into Resty. +// Error method returns the global or client common `Error` object into Resty. +func (c *Client) Error() reflect.Type { + c.lock.RLock() + defer c.lock.RUnlock() + return c.error +} + +// SetError method is to register the global or client common `error` object into Resty. // It is used for automatic unmarshalling if response status code is greater than 399 and // content type either JSON or XML. Can be pointer or non-pointer. // @@ -662,7 +823,9 @@ func (c *Client) SetTimeout(timeout time.Duration) *Client { // // OR // client.SetError(Error{}) func (c *Client) SetError(err interface{}) *Client { - c.Error = typeOf(err) + c.lock.Lock() + defer c.lock.Unlock() + c.error = typeOf(err) return c } @@ -681,6 +844,8 @@ func (c *Client) SetRedirectPolicy(policies ...interface{}) *Client { } } + c.lock.Lock() + defer c.lock.Unlock() c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { for _, p := range policies { if err := p.(RedirectPolicy).Apply(req, via); err != nil { @@ -689,78 +854,156 @@ func (c *Client) SetRedirectPolicy(policies ...interface{}) *Client { } return nil // looks good, go ahead } - return c } +// RetryCount method gets retry count in Resty client. +func (c *Client) RetryCount() int { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryCount +} + // SetRetryCount method enables retry on Resty client and allows you // to set no. of retry count. Resty uses a Backoff mechanism. func (c *Client) SetRetryCount(count int) *Client { - c.RetryCount = count + c.lock.Lock() + defer c.lock.Unlock() + c.retryCount = count return c } +// RetryWaitTime gets default wait time to sleep before retrying requeset. +func (c *Client) RetryWaitTime() time.Duration { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryWaitTime +} + // SetRetryWaitTime method sets default wait time to sleep before retrying // request. // // Default is 100 milliseconds. func (c *Client) SetRetryWaitTime(waitTime time.Duration) *Client { - c.RetryWaitTime = waitTime + c.lock.Lock() + defer c.lock.Unlock() + c.retryWaitTime = waitTime return c } +// RetryMaxWaitTime method gets max wait time to sleep before retrying request. +func (c *Client) RetryMaxWaitTime() time.Duration { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryMaxWaitTime +} + // SetRetryMaxWaitTime method sets max wait time to sleep before retrying // request. // // Default is 2 seconds. func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { - c.RetryMaxWaitTime = maxWaitTime + c.lock.Lock() + defer c.lock.Unlock() + c.retryMaxWaitTime = maxWaitTime return c } +// RetryAfter gets callback to calculate wait time between retries. +func (c *Client) RetryAfter() RetryAfterFunc { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryAfter +} + // SetRetryAfter sets callback to calculate wait time between retries. // Default (nil) implies exponential backoff with jitter func (c *Client) SetRetryAfter(callback RetryAfterFunc) *Client { - c.RetryAfter = callback + c.lock.Lock() + defer c.lock.Unlock() + c.retryAfter = callback return c } +// JSONMarshaler method gets the JSON marshaler function to marshal the request body. +func (c *Client) JSONMarshaler() func(v interface{}) ([]byte, error) { + c.lock.RLock() + defer c.lock.RUnlock() + return c.jsonMarshal +} + // SetJSONMarshaler method sets the JSON marshaler function to marshal the request body. // By default, Resty uses `encoding/json` package to marshal the request body. // // Since v2.8.0 func (c *Client) SetJSONMarshaler(marshaler func(v interface{}) ([]byte, error)) *Client { - c.JSONMarshal = marshaler + c.lock.Lock() + defer c.lock.Unlock() + c.jsonMarshal = marshaler return c } +// JSONUnmarshaler method gets the JSON unmarshaler function to unmarshal the response body. +func (c *Client) JSONUnmarshaler() func([]byte, interface{}) error { + c.lock.RLock() + defer c.lock.RUnlock() + return c.jsonUnmarshal +} + // SetJSONUnmarshaler method sets the JSON unmarshaler function to unmarshal the response body. // By default, Resty uses `encoding/json` package to unmarshal the response body. // // Since v2.8.0 func (c *Client) SetJSONUnmarshaler(unmarshaler func(data []byte, v interface{}) error) *Client { - c.JSONUnmarshal = unmarshaler + c.lock.Lock() + defer c.lock.Unlock() + c.jsonUnmarshal = unmarshaler return c } +// XMLMarshaler method gets the XML marshaler function to marshal the request body. +func (c *Client) XMLMarshaler() func(interface{}) ([]byte, error) { + c.lock.RLock() + defer c.lock.RUnlock() + return c.xmlMarshal +} + // SetXMLMarshaler method sets the XML marshaler function to marshal the request body. // By default, Resty uses `encoding/xml` package to marshal the request body. // // Since v2.8.0 func (c *Client) SetXMLMarshaler(marshaler func(v interface{}) ([]byte, error)) *Client { - c.XMLMarshal = marshaler + c.lock.Lock() + defer c.lock.Unlock() + c.xmlMarshal = marshaler return c } +// XMLUnmarshaler method gets the XML unmarshaler function to unmarshal the response body. +func (c *Client) XMLUnmarshaler() func([]byte, interface{}) error { + c.lock.RLock() + defer c.lock.RUnlock() + return c.xmlUnmarshal +} + // SetXMLUnmarshaler method sets the XML unmarshaler function to unmarshal the response body. // By default, Resty uses `encoding/xml` package to unmarshal the response body. // // Since v2.8.0 func (c *Client) SetXMLUnmarshaler(unmarshaler func(data []byte, v interface{}) error) *Client { - c.XMLUnmarshal = unmarshaler + c.lock.Lock() + defer c.lock.Unlock() + c.xmlUnmarshal = unmarshaler return c } +// RetryConditions method gets all retry condition functions. +func (c *Client) RetryConditions() []RetryConditionFunc { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryConditions +} + // AddRetryCondition method adds a retry condition function to array of functions // that are checked to determine if the request is retried. The request will // retry if any of the functions return true and error is nil. @@ -768,7 +1011,9 @@ func (c *Client) SetXMLUnmarshaler(unmarshaler func(data []byte, v interface{}) // Note: These retry conditions are applied on all Request made using this Client. // For Request specific retry conditions check *Request.AddRetryCondition func (c *Client) AddRetryCondition(condition RetryConditionFunc) *Client { - c.RetryConditions = append(c.RetryConditions, condition) + c.lock.Lock() + defer c.lock.Unlock() + c.retryConditions = append(c.retryConditions, condition) return c } @@ -783,21 +1028,40 @@ func (c *Client) AddRetryAfterErrorCondition() *Client { return c } +// RetryHooks gets all retry hooks. +func (c *Client) RetryHooks() []OnRetryFunc { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryHooks +} + // AddRetryHook adds a side-effecting retry hook to an array of hooks // that will be executed on each retry. // // Since v2.6.0 func (c *Client) AddRetryHook(hook OnRetryFunc) *Client { - c.RetryHooks = append(c.RetryHooks, hook) + c.lock.Lock() + defer c.lock.Unlock() + c.retryHooks = append(c.retryHooks, hook) return c } +// RetryResetReaders method gets if the Resty client is enabled to seek the start +// of all file readers given as multipart files. +func (c *Client) RetryResetReaders() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryResetReaders +} + // SetRetryResetReaders method enables the Resty client to seek the start of all // file readers given as multipart files, if the given object implements `io.ReadSeeker`. // // Since ... func (c *Client) SetRetryResetReaders(b bool) *Client { - c.RetryResetReaders = b + c.lock.Lock() + defer c.lock.Unlock() + c.retryResetReaders = b return c } @@ -818,6 +1082,8 @@ func (c *Client) SetTLSClientConfig(config *tls.Config) *Client { c.log.Errorf("%v", err) return c } + c.lock.Lock() + defer c.lock.Unlock() transport.TLSClientConfig = config return c } @@ -842,6 +1108,8 @@ func (c *Client) SetProxy(proxyURL string) *Client { return c } + c.lock.Lock() + defer c.lock.Unlock() c.proxyURL = pURL transport.Proxy = http.ProxyURL(c.proxyURL) return c @@ -856,6 +1124,9 @@ func (c *Client) RemoveProxy() *Client { c.log.Errorf("%v", err) return c } + + c.lock.Lock() + defer c.lock.Unlock() c.proxyURL = nil transport.Proxy = nil return c @@ -868,6 +1139,9 @@ func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { c.log.Errorf("%v", err) return c } + + c.lock.Lock() + defer c.lock.Unlock() config.Certificates = append(config.Certificates, certs...) return c } @@ -887,6 +1161,9 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client { c.log.Errorf("%v", err) return c } + + c.lock.Lock() + defer c.lock.Unlock() if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } @@ -904,6 +1181,9 @@ func (c *Client) SetRootCertificateFromString(pemContent string) *Client { c.log.Errorf("%v", err) return c } + + c.lock.Lock() + defer c.lock.Unlock() if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } @@ -918,6 +1198,8 @@ func (c *Client) SetRootCertificateFromString(pemContent string) *Client { // // client.SetOutputDirectory("/save/http/response/here") func (c *Client) SetOutputDirectory(dirPath string) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.outputDirectory = dirPath return c } @@ -927,6 +1209,8 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client { // // Since v2.9.0 func (c *Client) SetRateLimiter(rl RateLimiter) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.rateLimiter = rl return c } @@ -950,6 +1234,8 @@ func (c *Client) SetRateLimiter(rl RateLimiter) *Client { // // client.SetTransport(transport) func (c *Client) SetTransport(transport http.RoundTripper) *Client { + c.lock.Lock() + defer c.lock.Unlock() if transport != nil { c.httpClient.Transport = transport } @@ -960,6 +1246,8 @@ func (c *Client) SetTransport(transport http.RoundTripper) *Client { // // client.SetScheme("http") func (c *Client) SetScheme(scheme string) *Client { + c.lock.Lock() + defer c.lock.Unlock() if !IsStringEmpty(scheme) { c.scheme = strings.TrimSpace(scheme) } @@ -969,6 +1257,8 @@ func (c *Client) SetScheme(scheme string) *Client { // SetCloseConnection method sets variable `Close` in http request struct with the given // value. More info: https://golang.org/src/net/http/request.go func (c *Client) SetCloseConnection(close bool) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.closeConnection = close return c } @@ -980,10 +1270,20 @@ func (c *Client) SetCloseConnection(close bool) *Client { // Note: Response middlewares are not applicable, if you use this option. Basically you have // taken over the control of response parsing from `Resty`. func (c *Client) SetDoNotParseResponse(parse bool) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.notParseResponse = parse return c } +// PathParam method gets single URL path key-value pair in the +// Resty client instance. +func (c *Client) PathParam() map[string]string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.pathParams +} + // SetPathParam method sets single URL path key-value pair in the // Resty client instance. // @@ -999,7 +1299,9 @@ func (c *Client) SetDoNotParseResponse(parse bool) *Client { // Also it can be overridden at request level Path Params options, // see `Request.SetPathParam` or `Request.SetPathParams`. func (c *Client) SetPathParam(param, value string) *Client { - c.PathParams[param] = value + c.lock.Lock() + defer c.lock.Unlock() + c.pathParams[param] = value return c } @@ -1028,6 +1330,14 @@ func (c *Client) SetPathParams(params map[string]string) *Client { return c } +// RawPathParam method gets single URL path key-value pair in the +// Resty client instance. +func (c *Client) RawPathParam() map[string]string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.rawPathParams +} + // SetRawPathParam method sets single URL path key-value pair in the // Resty client instance. // @@ -1051,7 +1361,9 @@ func (c *Client) SetPathParams(params map[string]string) *Client { // // Since v2.8.0 func (c *Client) SetRawPathParam(param, value string) *Client { - c.RawPathParams[param] = value + c.lock.Lock() + defer c.lock.Unlock() + c.rawPathParams[param] = value return c } @@ -1086,6 +1398,8 @@ func (c *Client) SetRawPathParams(params map[string]string) *Client { // // Note: This option only applicable to standard JSON Marshaller. func (c *Client) SetJSONEscapeHTML(b bool) *Client { + c.lock.Lock() + defer c.lock.Unlock() c.jsonEscapeHTML = b return c } @@ -1096,13 +1410,15 @@ func (c *Client) SetJSONEscapeHTML(b bool) *Client { // client := resty.New().EnableTrace() // // resp, err := client.R().Get("https://httpbin.org/get") -// fmt.Println("Error:", err) +// fmt.Println("error:", err) // fmt.Println("Trace Info:", resp.Request.TraceInfo()) // // Also `Request.EnableTrace` available too to get trace info for single request. // // Since v2.0.0 func (c *Client) EnableTrace() *Client { + c.lock.Lock() + defer c.lock.Unlock() c.trace = true return c } @@ -1111,6 +1427,8 @@ func (c *Client) EnableTrace() *Client { // // Since v2.0.0 func (c *Client) DisableTrace() *Client { + c.lock.Lock() + defer c.lock.Unlock() c.trace = false return c } @@ -1118,6 +1436,8 @@ func (c *Client) DisableTrace() *Client { // IsProxySet method returns the true is proxy is set from resty client otherwise // false. By default proxy is set from environment, refer to `http.ProxyFromEnvironment`. func (c *Client) IsProxySet() bool { + c.lock.RLock() + defer c.lock.RUnlock() return c.proxyURL != nil } @@ -1125,6 +1445,8 @@ func (c *Client) IsProxySet() bool { // // Since v1.1.0 func (c *Client) GetClient() *http.Client { + c.lock.RLock() + defer c.lock.RUnlock() return c.httpClient } @@ -1133,6 +1455,8 @@ func (c *Client) GetClient() *http.Client { // // Since v2.8.0 become exported method. func (c *Client) Transport() (*http.Transport, error) { + c.lock.RLock() + defer c.lock.RUnlock() if transport, ok := c.httpClient.Transport.(*http.Transport); ok { return transport, nil } @@ -1155,6 +1479,7 @@ func (c *Client) Clone() *Client { // lock values should not be copied - thus new values are used. cc.afterResponseLock = &sync.RWMutex{} cc.udBeforeRequestLock = &sync.RWMutex{} + cc.lock = &sync.RWMutex{} return &cc } @@ -1271,6 +1596,8 @@ func (c *Client) tlsConfig() (*tls.Config, error) { if err != nil { return nil, err } + c.lock.Lock() + defer c.lock.Unlock() if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } @@ -1302,6 +1629,8 @@ func (e *ResponseError) Unwrap() error { // It wraps the error in a ResponseError if the resp is not nil // so hooks can access it. func (c *Client) onErrorHooks(req *Request, resp *Response, err error) { + c.lock.RLock() + defer c.lock.RUnlock() if err != nil { if resp != nil { // wrap with ResponseError err = &ResponseError{Response: resp, Err: err} @@ -1318,6 +1647,8 @@ func (c *Client) onErrorHooks(req *Request, resp *Response, err error) { // Helper to run panicHooks hooks. func (c *Client) onPanicHooks(req *Request, err error) { + c.lock.RLock() + defer c.lock.RUnlock() for _, h := range c.panicHooks { h(req, err) } @@ -1325,6 +1656,8 @@ func (c *Client) onPanicHooks(req *Request, err error) { // Helper to run invalidHooks hooks. func (c *Client) onInvalidHooks(req *Request, err error) { + c.lock.RLock() + defer c.lock.RUnlock() for _, h := range c.invalidHooks { h(req, err) } diff --git a/client_test.go b/client_test.go index faa27103..80932924 100644 --- a/client_test.go +++ b/client_test.go @@ -238,7 +238,7 @@ func TestClientProxy(t *testing.T) { assertNotNil(t, resp) assertNotNil(t, err) - // Error + // error c.SetProxy("//not.a.user@%66%6f%6f.com:8888") resp, err = c.R(). @@ -337,9 +337,9 @@ func TestClientSetHeaderVerbatim(t *testing.T) { SetHeader("header-lowercase", "value_standard") //lint:ignore SA1008 valid one, so ignore this! - unConventionHdrValue := strings.Join(c.Header["header-lowercase"], "") + unConventionHdrValue := strings.Join(c.Header()["header-lowercase"], "") assertEqual(t, "value_lowercase", unConventionHdrValue) - assertEqual(t, "value_standard", c.Header.Get("Header-Lowercase")) + assertEqual(t, "value_standard", c.Header().Get("Header-Lowercase")) } func TestClientSetTransport(t *testing.T) { @@ -385,20 +385,20 @@ func TestClientOptions(t *testing.T) { assertEqual(t, client.setContentLength, true) client.SetBaseURL("http://httpbin.org") - assertEqual(t, "http://httpbin.org", client.BaseURL) + assertEqual(t, "http://httpbin.org", client.BaseURL()) client.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8") client.SetHeaders(map[string]string{ hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.1", "X-Request-Id": strconv.FormatInt(time.Now().UnixNano(), 10), }) - assertEqual(t, "application/json; charset=utf-8", client.Header.Get(hdrContentTypeKey)) + assertEqual(t, "application/json; charset=utf-8", client.Header().Get(hdrContentTypeKey)) client.SetCookie(&http.Cookie{ Name: "default-cookie", Value: "This is cookie default-cookie value", }) - assertEqual(t, "default-cookie", client.Cookies[0].Name) + assertEqual(t, "default-cookie", client.Cookies()[0].Name) cookies := []*http.Cookie{ { @@ -410,45 +410,45 @@ func TestClientOptions(t *testing.T) { }, } client.SetCookies(cookies) - assertEqual(t, "default-cookie-1", client.Cookies[1].Name) - assertEqual(t, "default-cookie-2", client.Cookies[2].Name) + assertEqual(t, "default-cookie-1", client.Cookies()[1].Name) + assertEqual(t, "default-cookie-2", client.Cookies()[2].Name) client.SetQueryParam("test_param_1", "Param_1") client.SetQueryParams(map[string]string{"test_param_2": "Param_2", "test_param_3": "Param_3"}) - assertEqual(t, "Param_3", client.QueryParam.Get("test_param_3")) + assertEqual(t, "Param_3", client.QueryParam().Get("test_param_3")) rTime := strconv.FormatInt(time.Now().UnixNano(), 10) client.SetFormData(map[string]string{"r_time": rTime}) - assertEqual(t, rTime, client.FormData.Get("r_time")) + assertEqual(t, rTime, client.FormData().Get("r_time")) client.SetBasicAuth("myuser", "mypass") - assertEqual(t, "myuser", client.UserInfo.Username) + assertEqual(t, "myuser", client.UserInfo().Username) client.SetAuthToken("AC75BD37F019E08FBC594900518B4F7E") - assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token) + assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token()) client.SetDisableWarn(true) - assertEqual(t, client.DisableWarn, true) + assertEqual(t, client.DisableWarn(), true) client.SetRetryCount(3) - assertEqual(t, 3, client.RetryCount) + assertEqual(t, 3, client.RetryCount()) rwt := time.Duration(1000) * time.Millisecond client.SetRetryWaitTime(rwt) - assertEqual(t, rwt, client.RetryWaitTime) + assertEqual(t, rwt, client.RetryWaitTime()) mrwt := time.Duration(2) * time.Second client.SetRetryMaxWaitTime(mrwt) - assertEqual(t, mrwt, client.RetryMaxWaitTime) + assertEqual(t, mrwt, client.RetryMaxWaitTime()) client.AddRetryAfterErrorCondition() - equal(client.RetryConditions[0], func(response *Response, err error) bool { + equal(client.RetryConditions()[0], func(response *Response, err error) bool { return response.IsError() }) err := &AuthError{} client.SetError(err) - if reflect.TypeOf(err) == client.Error { + if reflect.TypeOf(err) == client.Error() { t.Error("SetError failed") } @@ -474,14 +474,14 @@ func TestClientOptions(t *testing.T) { client.SetContentLength(true) client.SetDebug(true) - assertEqual(t, client.Debug, true) + assertEqual(t, client.Debug(), true) var sl int64 = 1000000 client.SetDebugBodyLimit(sl) assertEqual(t, client.debugBodySizeLimit, sl) client.SetAllowGetMethodPayload(true) - assertEqual(t, client.AllowGetMethodPayload, true) + assertEqual(t, client.AllowGetMethodPayload(), true) client.SetScheme("http") assertEqual(t, client.scheme, "http") @@ -615,7 +615,7 @@ func TestClientNewRequest(t *testing.T) { func TestClientSetJSONMarshaler(t *testing.T) { m := func(v interface{}) ([]byte, error) { return nil, nil } c := New().SetJSONMarshaler(m) - p1 := fmt.Sprintf("%p", c.JSONMarshal) + p1 := fmt.Sprintf("%p", c.JSONMarshaler()) p2 := fmt.Sprintf("%p", m) assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers } @@ -623,7 +623,7 @@ func TestClientSetJSONMarshaler(t *testing.T) { func TestClientSetJSONUnmarshaler(t *testing.T) { m := func([]byte, interface{}) error { return nil } c := New().SetJSONUnmarshaler(m) - p1 := fmt.Sprintf("%p", c.JSONUnmarshal) + p1 := fmt.Sprintf("%p", c.JSONUnmarshaler()) p2 := fmt.Sprintf("%p", m) assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers } @@ -631,7 +631,7 @@ func TestClientSetJSONUnmarshaler(t *testing.T) { func TestClientSetXMLMarshaler(t *testing.T) { m := func(v interface{}) ([]byte, error) { return nil, nil } c := New().SetXMLMarshaler(m) - p1 := fmt.Sprintf("%p", c.XMLMarshal) + p1 := fmt.Sprintf("%p", c.XMLMarshaler()) p2 := fmt.Sprintf("%p", m) assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers } @@ -639,7 +639,7 @@ func TestClientSetXMLMarshaler(t *testing.T) { func TestClientSetXMLUnmarshaler(t *testing.T) { m := func([]byte, interface{}) error { return nil } c := New().SetXMLUnmarshaler(m) - p1 := fmt.Sprintf("%p", c.XMLUnmarshal) + p1 := fmt.Sprintf("%p", c.XMLUnmarshaler()) p2 := fmt.Sprintf("%p", m) assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers } @@ -1114,21 +1114,21 @@ func TestClone(t *testing.T) { parent.SetBaseURL("http://localhost") // set an interface field - parent.UserInfo = &User{ + parent.SetUserInfo(&User{ Username: "parent", - } + }) clone := parent.Clone() // update value of non-interface type - change will only happen on clone clone.SetBaseURL("https://local.host") // update value of interface type - change will also happen on parent - clone.UserInfo.Username = "clone" + clone.UserInfo().Username = "clone" // asert non-interface type - assertEqual(t, "http://localhost", parent.BaseURL) - assertEqual(t, "https://local.host", clone.BaseURL) + assertEqual(t, "http://localhost", parent.BaseURL()) + assertEqual(t, "https://local.host", clone.BaseURL()) // assert interface type - assertEqual(t, "clone", parent.UserInfo.Username) - assertEqual(t, "clone", clone.UserInfo.Username) + assertEqual(t, "clone", parent.UserInfo().Username) + assertEqual(t, "clone", clone.UserInfo().Username) } diff --git a/middleware.go b/middleware.go index d116ae0f..b931dc2b 100644 --- a/middleware.go +++ b/middleware.go @@ -27,14 +27,14 @@ const debugRequestLogKey = "__restyDebugRequestLog" //_______________________________________________________________________ func parseRequestURL(c *Client, r *Request) error { - if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 { + if l := len(c.pathParams) + len(c.rawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 { params := make(map[string]string, l) // GitHub #103 Path Params for p, v := range r.PathParams { params[p] = url.PathEscape(v) } - for p, v := range c.PathParams { + for p, v := range c.pathParams { if _, ok := params[p]; !ok { params[p] = url.PathEscape(v) } @@ -46,7 +46,7 @@ func parseRequestURL(c *Client, r *Request) error { params[p] = v } } - for p, v := range c.RawPathParams { + for p, v := range c.rawPathParams { if _, ok := params[p]; !ok { params[p] = v } @@ -114,7 +114,7 @@ func parseRequestURL(c *Client, r *Request) error { r.URL = "/" + r.URL } - reqURL, err = url.Parse(c.BaseURL + r.URL) + reqURL, err = url.Parse(c.baseURL + r.URL) if err != nil { return err } @@ -126,8 +126,8 @@ func parseRequestURL(c *Client, r *Request) error { } // Adding Query Param - if len(c.QueryParam)+len(r.QueryParam) > 0 { - for k, v := range c.QueryParam { + if len(c.queryParam)+len(r.QueryParam) > 0 { + for k, v := range c.queryParam { // skip query parameter if it was set in request if _, ok := r.QueryParam[k]; ok { continue @@ -155,7 +155,7 @@ func parseRequestURL(c *Client, r *Request) error { } func parseRequestHeader(c *Client, r *Request) error { - for k, v := range c.Header { + for k, v := range c.header { if _, ok := r.Header[k]; ok { continue } @@ -174,13 +174,13 @@ func parseRequestHeader(c *Client, r *Request) error { } func parseRequestBody(c *Client, r *Request) error { - if isPayloadSupported(r.Method, c.AllowGetMethodPayload) { + if isPayloadSupported(r.Method, c.allowGetMethodPayload) { switch { case r.isMultiPart: // Handling Multipart if err := handleMultipart(c, r); err != nil { return err } - case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data + case len(c.formData) > 0 || len(r.FormData) > 0: // Handling Form Data handleFormData(c, r) case r.Body != nil: // Handling Request body handleContentType(c, r) @@ -205,7 +205,7 @@ func parseRequestBody(c *Client, r *Request) error { func createHTTPRequest(c *Client, r *Request) (err error) { if r.bodyBuf == nil { - if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) { + if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.allowGetMethodPayload) { r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader) } else if c.setContentLength || r.setContentLength { r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody) @@ -229,7 +229,7 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest.Header = r.Header // Add cookies from client instance into http request - for _, cookie := range c.Cookies { + for _, cookie := range c.cookies { r.RawRequest.AddCookie(cookie) } @@ -271,32 +271,32 @@ func addCredentials(c *Client, r *Request) error { if r.UserInfo != nil { // takes precedence r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password) isBasicAuth = true - } else if c.UserInfo != nil { - r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password) + } else if c.userInfo != nil { + r.RawRequest.SetBasicAuth(c.userInfo.Username, c.userInfo.Password) isBasicAuth = true } - if !c.DisableWarn { + if !c.disableWarn { if isBasicAuth && !strings.HasPrefix(r.URL, "https") { r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS") } } - // Set the Authorization Header Scheme + // Set the Authorization header Scheme var authScheme string if !IsStringEmpty(r.AuthScheme) { authScheme = r.AuthScheme - } else if !IsStringEmpty(c.AuthScheme) { - authScheme = c.AuthScheme + } else if !IsStringEmpty(c.authScheme) { + authScheme = c.authScheme } else { authScheme = "Bearer" } - // Build the Token Auth header + // Build the token Auth header if !IsStringEmpty(r.Token) { // takes precedence - r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+r.Token) - } else if !IsStringEmpty(c.Token) { - r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+c.Token) + r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+r.Token) + } else if !IsStringEmpty(c.token) { + r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+c.token) } return nil @@ -389,11 +389,11 @@ func parseResponseBody(c *Client, res *Response) (err error) { } } - // HTTP status code > 399, considered as Error + // HTTP status code > 399, considered as error if res.IsError() { // global error interface - if res.Request.Error == nil && c.Error != nil { - res.Request.Error = reflect.New(c.Error).Interface() + if res.Request.Error == nil && c.error != nil { + res.Request.Error = reflect.New(c.error).Interface() } if res.Request.Error != nil { @@ -412,7 +412,7 @@ func handleMultipart(c *Client, r *Request) error { r.bodyBuf = acquireBuffer() w := multipart.NewWriter(r.bodyBuf) - for k, v := range c.FormData { + for k, v := range c.formData { for _, iv := range v { if err := w.WriteField(k, iv); err != nil { return err @@ -453,7 +453,7 @@ func handleMultipart(c *Client, r *Request) error { } func handleFormData(c *Client, r *Request) { - for k, v := range c.FormData { + for k, v := range c.formData { if _, ok := r.FormData[k]; ok { continue } @@ -501,7 +501,7 @@ func handleRequestBody(c *Client, r *Request) error { if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { r.bodyBuf, err = jsonMarshal(c, r, r.Body) } else if IsXMLType(contentType) && (kind == reflect.Struct) { - bodyBytes, err = c.XMLMarshal(r.Body) + bodyBytes, err = c.xmlMarshal(r.Body) } if err != nil { return err diff --git a/middleware_test.go b/middleware_test.go index d585221c..694d5a71 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -98,7 +98,7 @@ func Test_parseRequestURL(t *testing.T) { r.SetPathParams(map[string]string{ "foo": "4/5", }).SetRawPathParams(map[string]string{ - "foo": "4/5", // ignored, because the PathParams takes precedence over the RawPathParams + "foo": "4/5", // ignored, because the pathParams takes precedence over the rawPathParams "bar": "6/7", }) r.URL = "https://example.com/{foo}/{bar}" @@ -182,7 +182,7 @@ func Test_parseRequestURL(t *testing.T) { { name: "using deprecated HostURL with relative path in request URL", init: func(c *Client, r *Request) { - c.BaseURL = "https://example.com" + c.SetBaseURL("https://example.com") r.URL = "foo/bar" }, expectedURL: "https://example.com/foo/bar", diff --git a/request.go b/request.go index 4e13ff09..5f4e25d3 100644 --- a/request.go +++ b/request.go @@ -931,7 +931,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { r.Method = method r.URL = r.selectAddr(addrs, url, 0) - if r.client.RetryCount == 0 { + if r.client.retryCount == 0 { r.Attempt = 1 resp, err = r.client.execute(r) r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) @@ -951,12 +951,12 @@ func (r *Request) Execute(method, url string) (*Response, error) { return resp, err }, - Retries(r.client.RetryCount), - WaitTime(r.client.RetryWaitTime), - MaxWaitTime(r.client.RetryMaxWaitTime), - RetryConditions(append(r.retryConditions, r.client.RetryConditions...)), - RetryHooks(r.client.RetryHooks), - ResetMultipartReaders(r.client.RetryResetReaders), + Retries(r.client.retryCount), + WaitTime(r.client.retryWaitTime), + MaxWaitTime(r.client.retryMaxWaitTime), + RetryConditions(append(r.retryConditions, r.client.retryConditions...)), + RetryHooks(r.client.retryHooks), + ResetMultipartReaders(r.client.retryResetReaders), ) if err != nil { @@ -984,7 +984,7 @@ type SRVRecord struct { func (r *Request) fmtBodyString(sl int64) (body string) { body = "***** NO CONTENT *****" - if !isPayloadSupported(r.Method, r.client.AllowGetMethodPayload) { + if !isPayloadSupported(r.Method, r.client.allowGetMethodPayload) { return } diff --git a/resty.go b/resty.go index 7ceaba86..f2ebc619 100644 --- a/resty.go +++ b/resty.go @@ -123,25 +123,26 @@ func createCookieJar() *cookiejar.Jar { func createClient(hc *http.Client) *Client { c := &Client{ // not setting language default values - QueryParam: url.Values{}, - FormData: url.Values{}, - Header: http.Header{}, - Cookies: make([]*http.Cookie, 0), - RetryWaitTime: defaultWaitTime, - RetryMaxWaitTime: defaultMaxWaitTime, - PathParams: make(map[string]string), - RawPathParams: make(map[string]string), - JSONMarshal: json.Marshal, - JSONUnmarshal: json.Unmarshal, - XMLMarshal: xml.Marshal, - XMLUnmarshal: xml.Unmarshal, - HeaderAuthorizationKey: http.CanonicalHeaderKey("Authorization"), + queryParam: url.Values{}, + formData: url.Values{}, + header: http.Header{}, + cookies: make([]*http.Cookie, 0), + retryWaitTime: defaultWaitTime, + retryMaxWaitTime: defaultMaxWaitTime, + pathParams: make(map[string]string), + rawPathParams: make(map[string]string), + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, + headerAuthorizationKey: http.CanonicalHeaderKey("Authorization"), jsonEscapeHTML: true, httpClient: hc, debugBodySizeLimit: math.MaxInt32, udBeforeRequestLock: &sync.RWMutex{}, afterResponseLock: &sync.RWMutex{}, + lock: &sync.RWMutex{}, } // Logger diff --git a/retry.go b/retry.go index c5eda26b..ebb49d51 100644 --- a/retry.go +++ b/retry.go @@ -178,7 +178,7 @@ func sleepDuration(resp *Response, min, max time.Duration, attempt int) (time.Du return jitterBackoff(min, max, attempt), nil } - retryAfterFunc := resp.Request.client.RetryAfter + retryAfterFunc := resp.Request.client.RetryAfter() // Check for custom callback if retryAfterFunc == nil { diff --git a/retry_test.go b/retry_test.go index 8d58cc16..72fa77b1 100644 --- a/retry_test.go +++ b/retry_test.go @@ -47,7 +47,7 @@ func TestBackoffNoWaitForLastRetry(t *testing.T) { Request: &Request{ ctx: canceledCtx, client: &Client{ - RetryAfter: func(*Client, *Response) (time.Duration, error) { + retryAfter: func(*Client, *Response) (time.Duration, error) { return 6, nil }, }, diff --git a/util.go b/util.go index 5a69e4fc..7054d6fe 100644 --- a/util.go +++ b/util.go @@ -117,9 +117,9 @@ func IsXMLType(ct string) bool { // Unmarshalc content into object from JSON or XML func Unmarshalc(c *Client, ct string, b []byte, d interface{}) (err error) { if IsJSONType(ct) { - err = c.JSONUnmarshal(b, d) + err = c.jsonUnmarshal(b, d) } else if IsXMLType(ct) { - err = c.XMLUnmarshal(b, d) + err = c.xmlUnmarshal(b, d) } return @@ -155,7 +155,7 @@ func jsonMarshal(c *Client, r *Request, d interface{}) (*bytes.Buffer, error) { return noescapeJSONMarshal(d) } - data, err := c.JSONMarshal(d) + data, err := c.jsonMarshal(d) if err != nil { return nil, err } From b4251987a274a21bb3510f911234ebc98f83d030 Mon Sep 17 00:00:00 2001 From: TurtleRuss0 Date: Fri, 30 Aug 2024 02:14:34 +0000 Subject: [PATCH 2/4] test: use New() method instead of creating an empty instance --- retry_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/retry_test.go b/retry_test.go index 72fa77b1..c2385771 100644 --- a/retry_test.go +++ b/retry_test.go @@ -43,14 +43,14 @@ func TestBackoffNoWaitForLastRetry(t *testing.T) { canceledCtx, cancel := context.WithCancel(context.Background()) defer cancel() + client := New() + client.SetRetryAfter(func(*Client, *Response) (time.Duration, error) { + return 6, nil + }) resp := &Response{ Request: &Request{ - ctx: canceledCtx, - client: &Client{ - retryAfter: func(*Client, *Response) (time.Duration, error) { - return 6, nil - }, - }, + ctx: canceledCtx, + client: client, }, } From 333f8de5acfbb2d5376fcd62e20f9f7953a6dd48 Mon Sep 17 00:00:00 2001 From: TurtleRuss Date: Sun, 1 Sep 2024 17:57:37 +0800 Subject: [PATCH 3/4] test: add assertions and remove duplicated method. --- client.go | 18 ++---------------- client_test.go | 13 ++++++------- request_test.go | 5 +++++ retry_test.go | 21 ++++++++++++++++----- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/client.go b/client.go index 4900f3eb..b5af3330 100644 --- a/client.go +++ b/client.go @@ -270,20 +270,6 @@ func (c *Client) SetHeaderVerbatim(header, value string) *Client { return c } -// UserInfo method gets the user information in the client instance. -func (c *Client) UserInfo() *User { - c.lock.RLock() - defer c.lock.RUnlock() - return c.userInfo -} - -func (c *Client) SetUserInfo(user *User) *Client { - c.lock.Lock() - defer c.lock.Unlock() - c.userInfo = user - return c -} - // SetCookieJar method sets custom http.CookieJar in the resty client. Its way to override default. // // For Example: sometimes we don't want to save cookies in api contacting, we can remove the default @@ -1330,9 +1316,9 @@ func (c *Client) SetPathParams(params map[string]string) *Client { return c } -// RawPathParam method gets single URL path key-value pair in the +// RawPathParams method gets single URL path key-value pair in the // Resty client instance. -func (c *Client) RawPathParam() map[string]string { +func (c *Client) RawPathParams() map[string]string { c.lock.RLock() defer c.lock.RUnlock() return c.rawPathParams diff --git a/client_test.go b/client_test.go index 80932924..db593301 100644 --- a/client_test.go +++ b/client_test.go @@ -75,6 +75,7 @@ func TestClientAuthScheme(t *testing.T) { // Ensure setting the scheme works as well c.SetAuthScheme("Bearer") + assertEqual(t, "Bearer", c.AuthScheme()) resp2, err2 := c.R().Get("/profile") assertError(t, err2) @@ -422,7 +423,7 @@ func TestClientOptions(t *testing.T) { assertEqual(t, rTime, client.FormData().Get("r_time")) client.SetBasicAuth("myuser", "mypass") - assertEqual(t, "myuser", client.UserInfo().Username) + assertEqual(t, "myuser", client.BasicAuth().Username) client.SetAuthToken("AC75BD37F019E08FBC594900518B4F7E") assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token()) @@ -1114,21 +1115,19 @@ func TestClone(t *testing.T) { parent.SetBaseURL("http://localhost") // set an interface field - parent.SetUserInfo(&User{ - Username: "parent", - }) + parent.SetBasicAuth("parent", "") clone := parent.Clone() // update value of non-interface type - change will only happen on clone clone.SetBaseURL("https://local.host") // update value of interface type - change will also happen on parent - clone.UserInfo().Username = "clone" + clone.BasicAuth().Username = "clone" // asert non-interface type assertEqual(t, "http://localhost", parent.BaseURL()) assertEqual(t, "https://local.host", clone.BaseURL()) // assert interface type - assertEqual(t, "clone", parent.UserInfo().Username) - assertEqual(t, "clone", clone.UserInfo().Username) + assertEqual(t, "clone", parent.BasicAuth().Username) + assertEqual(t, "clone", clone.BasicAuth().Username) } diff --git a/request_test.go b/request_test.go index 01e07844..05e2bb8a 100644 --- a/request_test.go +++ b/request_test.go @@ -1717,6 +1717,8 @@ func TestGetPathParamAndPathParams(t *testing.T) { SetBaseURL(ts.URL). SetPathParam("userId", "sample@sample.com") + assertEqual(t, "sample@sample.com", c.PathParam()["userId"]) + resp, err := c.R().SetPathParam("subAccountId", "100002"). Get("/v1/users/{userId}/{subAccountId}/details") @@ -1907,6 +1909,9 @@ func TestRawPathParamURLInput(t *testing.T) { "path": "users/developers", }) + assertEqual(t, "sample@sample.com", c.RawPathParams()["userId"]) + assertEqual(t, "users/developers", c.RawPathParams()["path"]) + resp, err := c.R(). SetRawPathParams(map[string]string{ "subAccountId": "100002", diff --git a/retry_test.go b/retry_test.go index c2385771..d6603a7d 100644 --- a/retry_test.go +++ b/retry_test.go @@ -724,14 +724,21 @@ func TestClientRetryHook(t *testing.T) { attempt := 0 + retryHook := func(r *Response, _ error) { + attempt++ + } + c := dc(). SetRetryCount(2). SetTimeout(time.Second * 3). - AddRetryHook( - func(r *Response, _ error) { - attempt++ - }, - ) + AddRetryHook(retryHook) + + // Since reflect.DeepEqual can not compare two functions + // just compare pointers of the two hooks + originHookPointer := reflect.ValueOf(retryHook).Pointer() + getterHookPointer := reflect.ValueOf(c.RetryHooks()[0]).Pointer() + + assertEqual(t, originHookPointer, getterHookPointer) resp, err := c.R().Get(ts.URL + "/set-retrycount-test") assertEqual(t, "", resp.Status()) @@ -783,6 +790,8 @@ func TestResetMultipartReaderSeekStartError(t *testing.T) { SetRetryResetReaders(true). AddRetryAfterErrorCondition() + assertEqual(t, true, c.RetryResetReaders()) + resp, err := c.R(). SetFileReader("name", "filename", testSeeker). Post(ts.URL + "/set-reset-multipart-readers-test") @@ -816,6 +825,8 @@ func TestResetMultipartReaders(t *testing.T) { }, ) + assertEqual(t, true, c.RetryResetReaders()) + resp, err := c.R(). SetFileReader("name", "filename", bufReader). Post(ts.URL + "/set-reset-multipart-readers-test") From abf737104aa7e67394d994dc78b4e2a25f2ab279 Mon Sep 17 00:00:00 2001 From: TurtleRuss Date: Mon, 2 Sep 2024 10:49:45 +0800 Subject: [PATCH 4/4] feat: make Client.ResponseBodyLimit private and thread-safe --- client.go | 16 ++++++++++++---- client_test.go | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 9c31bedf..d7d593ce 100644 --- a/client.go +++ b/client.go @@ -179,8 +179,7 @@ type Client struct { // value when `SetAuthToken` option is used. headerAuthorizationKey string - ResponseBodyLimit int - + responseBodyLimit int jsonEscapeHTML bool setContentLength bool @@ -575,7 +574,7 @@ func (c *Client) R() *Request { multipartFields: []*MultipartField{}, jsonEscapeHTML: c.jsonEscapeHTML, log: c.log, - responseBodyLimit: c.ResponseBodyLimit, + responseBodyLimit: c.responseBodyLimit, } return r } @@ -1427,6 +1426,13 @@ func (c *Client) SetJSONEscapeHTML(b bool) *Client { return c } +// ResponseBodyLimit gets the max body size limit on response. +func (c *Client) ResponseBodyLimit() int { + c.lock.RLock() + defer c.lock.RUnlock() + return c.responseBodyLimit +} + // SetResponseBodyLimit set a max body size limit on response, avoid reading too many data to memory. // // Client will return [resty.ErrResponseBodyTooLarge] if uncompressed response body size if larger than limit. @@ -1437,7 +1443,9 @@ func (c *Client) SetJSONEscapeHTML(b bool) *Client { // // this can be overridden at client level with [Request.SetResponseBodyLimit] func (c *Client) SetResponseBodyLimit(v int) *Client { - c.ResponseBodyLimit = v + c.lock.Lock() + defer c.lock.Unlock() + c.responseBodyLimit = v return c } diff --git a/client_test.go b/client_test.go index 5b21ef01..f07d986e 100644 --- a/client_test.go +++ b/client_test.go @@ -1171,7 +1171,7 @@ func TestResponseBodyLimit(t *testing.T) { t.Run("Client body limit", func(t *testing.T) { c := dc().SetResponseBodyLimit(1024) - + assertEqual(t, 1024, c.ResponseBodyLimit()) _, err := c.R().Get(ts.URL + "/") assertNotNil(t, err) assertEqual(t, err, ErrResponseBodyTooLarge)