diff --git a/client.go b/client.go index f6f25243..fd174ead 100644 --- a/client.go +++ b/client.go @@ -183,6 +183,7 @@ type Client struct { disableWarn bool allowMethodGetPayload bool allowMethodDeletePayload bool + timeout time.Duration retryCount int retryWaitTime time.Duration retryMaxWaitTime time.Duration @@ -612,6 +613,7 @@ func (c *Client) R() *Request { Cookies: make([]*http.Cookie, 0), PathParams: make(map[string]string), RawPathParams: make(map[string]string), + Timeout: c.timeout, Debug: c.debug, IsTrace: c.isTrace, AuthScheme: c.authScheme, @@ -1122,13 +1124,24 @@ func (c *Client) SetContentLength(l bool) *Client { return c } -// SetTimeout method sets the timeout for a request raised by the client. +// Timeout method returns the timeout duration value from the client +func (c *Client) Timeout() time.Duration { + c.lock.RLock() + defer c.lock.RUnlock() + return c.timeout +} + +// SetTimeout method is used to set a timeout for a request raised by the client. +// +// client.SetTimeout(1 * time.Minute) +// +// It can be overridden at the request level. See [Request.SetTimeout] // -// client.SetTimeout(time.Duration(1 * time.Minute)) +// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout] func (c *Client) SetTimeout(timeout time.Duration) *Client { c.lock.Lock() defer c.lock.Unlock() - c.httpClient.Timeout = timeout + c.timeout = timeout return c } @@ -2077,7 +2090,7 @@ func (c *Client) execute(req *Request) (*Response, error) { requestDebugLogger(c, req) req.Time = time.Now() - resp, err := c.Client().Do(req.RawRequest) + resp, err := c.Client().Do(req.withTimeout()) response := &Response{Request: req, RawResponse: resp} response.setReceivedAt() diff --git a/client_test.go b/client_test.go index a40c6851..f3fddc38 100644 --- a/client_test.go +++ b/client_test.go @@ -135,10 +135,9 @@ func TestClientTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() - c := dcnl().SetTimeout(time.Millisecond * 200) + c := dcnl().SetTimeout(200 * time.Millisecond) _, err := c.R().Get(ts.URL + "/set-timeout-test") - - assertEqual(t, true, strings.Contains(err.Error(), "Client.Timeout")) + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) } func TestClientTimeoutWithinThreshold(t *testing.T) { diff --git a/context_test.go b/context_test.go index 639bbc0d..e9433d4d 100644 --- a/context_test.go +++ b/context_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "net/http" - "strings" "sync/atomic" "testing" "time" @@ -212,8 +211,7 @@ func TestClientRetryWithSetContext(t *testing.T) { assertNotNil(t, ts) assertNotNil(t, err) - assertEqual(t, true, (strings.HasPrefix(err.Error(), "Get "+ts.URL+"/") || - strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/\""))) + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) } func TestRequestContext(t *testing.T) { diff --git a/request.go b/request.go index 3d759e1d..b490da50 100644 --- a/request.go +++ b/request.go @@ -61,6 +61,7 @@ type Request struct { AllowMethodGetPayload bool AllowMethodDeletePayload bool IsDone bool + Timeout time.Duration RetryCount int RetryWaitTime time.Duration RetryMaxWaitTime time.Duration @@ -84,6 +85,7 @@ type Request struct { isSaveResponse bool jsonEscapeHTML bool ctx context.Context + ctxCancelFunc context.CancelFunc values map[string]any client *Client bodyBuf *bytes.Buffer @@ -894,6 +896,18 @@ func (r *Request) SetCookies(rs []*http.Cookie) *Request { return r } +// SetTimeout method is used to set a timeout for the current request +// +// client.R().SetTimeout(1 * time.Minute) +// +// It overrides the timeout set at the client instance level, See [Client.SetTimeout] +// +// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout] +func (r *Request) SetTimeout(timeout time.Duration) *Request { + r.Timeout = timeout + return r +} + // SetLogger method sets given writer for logging Resty request and response details. // By default, requests and responses inherit their logger from the client. // @@ -1273,8 +1287,14 @@ func (r *Request) Execute(method, url string) (res *Response, err error) { break } if r.Context().Err() != nil { - err = wrapErrors(r.Context().Err(), err) - break + if r.ctxCancelFunc != nil { + r.ctxCancelFunc() + r.ctxCancelFunc = nil + } + if !errors.Is(err, context.DeadlineExceeded) { + err = wrapErrors(r.Context().Err(), err) + break + } } } @@ -1425,6 +1445,7 @@ func (r *Request) Clone(ctx context.Context) *Request { rr.initTraceIfEnabled() r.values = make(map[string]any) r.multipartErrChan = nil + r.ctxCancelFunc = nil // copy bodyBuf if r.bodyBuf != nil { @@ -1634,6 +1655,18 @@ func (r *Request) isIdempotent() bool { return found || r.AllowNonIdempotentRetry } +func (r *Request) withTimeout() *http.Request { + if _, found := r.Context().Deadline(); found { + return r.RawRequest + } + if r.Timeout > 0 { + ctx, ctxCancelFunc := context.WithTimeout(r.Context(), r.Timeout) + r.ctxCancelFunc = ctxCancelFunc + return r.RawRequest.WithContext(ctx) + } + return r.RawRequest +} + func jsonIndent(v []byte) []byte { buf := acquireBuffer() defer releaseBuffer(buf) diff --git a/request_test.go b/request_test.go index eaa17312..3f0202d0 100644 --- a/request_test.go +++ b/request_test.go @@ -2110,6 +2110,47 @@ func TestRequestNoRetryOnNonIdempotentMethod(t *testing.T) { assertEqual(t, 500, resp.StatusCode()) } +func TestRequestContextTimeout(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + t.Run("use client set timeout", func(t *testing.T) { + c := dcnl().SetTimeout(200 * time.Millisecond) + assertEqual(t, true, c.Timeout() > 0) + + req := c.R() + assertEqual(t, true, req.Timeout > 0) + + _, err := req.Get(ts.URL + "/set-timeout-test") + + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) + }) + + t.Run("use request set timeout", func(t *testing.T) { + c := dcnl() + assertEqual(t, true, c.Timeout() == 0) + + _, err := c.R(). + SetTimeout(200 * time.Millisecond). + Get(ts.URL + "/set-timeout-test") + + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) + }) + + t.Run("use external context for timeout", func(t *testing.T) { + ctx, ctxCancelFunc := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer ctxCancelFunc() + + c := dcnl() + _, err := c.R(). + SetContext(ctx). + Get(ts.URL + "/set-timeout-test") + + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) + }) + +} + func TestRequestPanicContext(t *testing.T) { defer func() { if r := recover(); r == nil { diff --git a/retry_test.go b/retry_test.go index 96448a54..eaa04214 100644 --- a/retry_test.go +++ b/retry_test.go @@ -85,12 +85,12 @@ func TestConditionalGetRequestLevel(t *testing.T) { logResponse(t, resp) } -func TestClientRetryGet(t *testing.T) { +func TestClientRetryGetWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). - SetTimeout(time.Millisecond * 50). + SetTimeout(50 * time.Millisecond). SetRetryCount(3) resp, err := c.R().Get(ts.URL + "/set-retrycount-test") @@ -99,9 +99,7 @@ func TestClientRetryGet(t *testing.T) { assertEqual(t, 0, resp.StatusCode()) assertEqual(t, 0, len(resp.Cookies())) assertEqual(t, 0, len(resp.Header())) - - assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") || - strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) } func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) { @@ -561,7 +559,7 @@ func TestClientRetryCountWithTimeout(t *testing.T) { attempt := 0 c := dcnl(). - SetTimeout(time.Millisecond * 50). + SetTimeout(50 * time.Millisecond). SetRetryCount(1). AddRetryCondition( func(r *Response, _ error) bool { @@ -576,12 +574,8 @@ func TestClientRetryCountWithTimeout(t *testing.T) { assertEqual(t, 0, resp.StatusCode()) assertEqual(t, 0, len(resp.Cookies())) assertEqual(t, 0, len(resp.Header())) - - // 2 attempts were made assertEqual(t, 2, resp.Request.Attempt) - - assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") || - strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) } func TestClientRetryTooManyRequestsAndRecover(t *testing.T) { @@ -596,6 +590,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) { SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). SetJSONEscapeHTML(false). SetResult(AuthSuccess{}). + SetTimeout(10 * time.Millisecond). Get(ts.URL + "/set-retry-error-recover") assertError(t, err) @@ -608,7 +603,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) { assertNil(t, resp.Error()) } -func TestClientRetryHook(t *testing.T) { +func TestClientRetryHookWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() @@ -641,9 +636,7 @@ func TestClientRetryHook(t *testing.T) { assertEqual(t, retryCount+1, resp.Request.Attempt) assertEqual(t, 3, hookCalledCount) - - assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") || - strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) + assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) } var errSeekFailure = fmt.Errorf("failing seek test")