Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use context timeout instead of http client timeout option #545 #922

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ type Client struct {
disableWarn bool
allowMethodGetPayload bool
allowMethodDeletePayload bool
timeout time.Duration
retryCount int
retryWaitTime time.Duration
retryMaxWaitTime time.Duration
Expand Down Expand Up @@ -612,6 +613,7 @@ func (c *Client) R() *Request {
Cookies: make([]*http.Cookie, 0),
PathParams: make(map[string]string),
RawPathParams: make(map[string]string),
Timeout: c.timeout,
Debug: c.debug,
IsTrace: c.isTrace,
AuthScheme: c.authScheme,
Expand Down Expand Up @@ -1122,13 +1124,24 @@ func (c *Client) SetContentLength(l bool) *Client {
return c
}

// SetTimeout method sets the timeout for a request raised by the client.
// Timeout method returns the timeout duration value from the client
func (c *Client) Timeout() time.Duration {
c.lock.RLock()
defer c.lock.RUnlock()
return c.timeout
}

// SetTimeout method is used to set a timeout for a request raised by the client.
//
// client.SetTimeout(1 * time.Minute)
//
// It can be overridden at the request level. See [Request.SetTimeout]
//
// client.SetTimeout(time.Duration(1 * time.Minute))
// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout]
func (c *Client) SetTimeout(timeout time.Duration) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.httpClient.Timeout = timeout
c.timeout = timeout
return c
}

Expand Down Expand Up @@ -2077,7 +2090,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
requestDebugLogger(c, req)

req.Time = time.Now()
resp, err := c.Client().Do(req.RawRequest)
resp, err := c.Client().Do(req.withTimeout())

response := &Response{Request: req, RawResponse: resp}
response.setReceivedAt()
Expand Down
5 changes: 2 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ func TestClientTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

c := dcnl().SetTimeout(time.Millisecond * 200)
c := dcnl().SetTimeout(200 * time.Millisecond)
_, err := c.R().Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, strings.Contains(err.Error(), "Client.Timeout"))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientTimeoutWithinThreshold(t *testing.T) {
Expand Down
4 changes: 1 addition & 3 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"errors"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -212,8 +211,7 @@ func TestClientRetryWithSetContext(t *testing.T) {

assertNotNil(t, ts)
assertNotNil(t, err)
assertEqual(t, true, (strings.HasPrefix(err.Error(), "Get "+ts.URL+"/") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/\"")))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestRequestContext(t *testing.T) {
Expand Down
37 changes: 35 additions & 2 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Request struct {
AllowMethodGetPayload bool
AllowMethodDeletePayload bool
IsDone bool
Timeout time.Duration
RetryCount int
RetryWaitTime time.Duration
RetryMaxWaitTime time.Duration
Expand All @@ -84,6 +85,7 @@ type Request struct {
isSaveResponse bool
jsonEscapeHTML bool
ctx context.Context
ctxCancelFunc context.CancelFunc
values map[string]any
client *Client
bodyBuf *bytes.Buffer
Expand Down Expand Up @@ -894,6 +896,18 @@ func (r *Request) SetCookies(rs []*http.Cookie) *Request {
return r
}

// SetTimeout method is used to set a timeout for the current request
//
// client.R().SetTimeout(1 * time.Minute)
//
// It overrides the timeout set at the client instance level, See [Client.SetTimeout]
//
// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout]
func (r *Request) SetTimeout(timeout time.Duration) *Request {
r.Timeout = timeout
return r
}

// SetLogger method sets given writer for logging Resty request and response details.
// By default, requests and responses inherit their logger from the client.
//
Expand Down Expand Up @@ -1273,8 +1287,14 @@ func (r *Request) Execute(method, url string) (res *Response, err error) {
break
}
if r.Context().Err() != nil {
err = wrapErrors(r.Context().Err(), err)
break
if r.ctxCancelFunc != nil {
r.ctxCancelFunc()
r.ctxCancelFunc = nil
}
if !errors.Is(err, context.DeadlineExceeded) {
err = wrapErrors(r.Context().Err(), err)
break
}
}
}

Expand Down Expand Up @@ -1425,6 +1445,7 @@ func (r *Request) Clone(ctx context.Context) *Request {
rr.initTraceIfEnabled()
r.values = make(map[string]any)
r.multipartErrChan = nil
r.ctxCancelFunc = nil

// copy bodyBuf
if r.bodyBuf != nil {
Expand Down Expand Up @@ -1634,6 +1655,18 @@ func (r *Request) isIdempotent() bool {
return found || r.AllowNonIdempotentRetry
}

func (r *Request) withTimeout() *http.Request {
if _, found := r.Context().Deadline(); found {
return r.RawRequest
}
if r.Timeout > 0 {
ctx, ctxCancelFunc := context.WithTimeout(r.Context(), r.Timeout)
r.ctxCancelFunc = ctxCancelFunc
return r.RawRequest.WithContext(ctx)
}
return r.RawRequest
}

func jsonIndent(v []byte) []byte {
buf := acquireBuffer()
defer releaseBuffer(buf)
Expand Down
41 changes: 41 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,47 @@ func TestRequestNoRetryOnNonIdempotentMethod(t *testing.T) {
assertEqual(t, 500, resp.StatusCode())
}

func TestRequestContextTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

t.Run("use client set timeout", func(t *testing.T) {
c := dcnl().SetTimeout(200 * time.Millisecond)
assertEqual(t, true, c.Timeout() > 0)

req := c.R()
assertEqual(t, true, req.Timeout > 0)

_, err := req.Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

t.Run("use request set timeout", func(t *testing.T) {
c := dcnl()
assertEqual(t, true, c.Timeout() == 0)

_, err := c.R().
SetTimeout(200 * time.Millisecond).
Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

t.Run("use external context for timeout", func(t *testing.T) {
ctx, ctxCancelFunc := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer ctxCancelFunc()

c := dcnl()
_, err := c.R().
SetContext(ctx).
Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

}

func TestRequestPanicContext(t *testing.T) {
defer func() {
if r := recover(); r == nil {
Expand Down
23 changes: 8 additions & 15 deletions retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ func TestConditionalGetRequestLevel(t *testing.T) {
logResponse(t, resp)
}

func TestClientRetryGet(t *testing.T) {
func TestClientRetryGetWithTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

c := dcnl().
SetTimeout(time.Millisecond * 50).
SetTimeout(50 * time.Millisecond).
SetRetryCount(3)

resp, err := c.R().Get(ts.URL + "/set-retrycount-test")
Expand All @@ -99,9 +99,7 @@ func TestClientRetryGet(t *testing.T) {
assertEqual(t, 0, resp.StatusCode())
assertEqual(t, 0, len(resp.Cookies()))
assertEqual(t, 0, len(resp.Header()))

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) {
Expand Down Expand Up @@ -561,7 +559,7 @@ func TestClientRetryCountWithTimeout(t *testing.T) {
attempt := 0

c := dcnl().
SetTimeout(time.Millisecond * 50).
SetTimeout(50 * time.Millisecond).
SetRetryCount(1).
AddRetryCondition(
func(r *Response, _ error) bool {
Expand All @@ -576,12 +574,8 @@ func TestClientRetryCountWithTimeout(t *testing.T) {
assertEqual(t, 0, resp.StatusCode())
assertEqual(t, 0, len(resp.Cookies()))
assertEqual(t, 0, len(resp.Header()))

// 2 attempts were made
assertEqual(t, 2, resp.Request.Attempt)

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
Expand All @@ -596,6 +590,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
SetJSONEscapeHTML(false).
SetResult(AuthSuccess{}).
SetTimeout(10 * time.Millisecond).
Get(ts.URL + "/set-retry-error-recover")

assertError(t, err)
Expand All @@ -608,7 +603,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
assertNil(t, resp.Error())
}

func TestClientRetryHook(t *testing.T) {
func TestClientRetryHookWithTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

Expand Down Expand Up @@ -641,9 +636,7 @@ func TestClientRetryHook(t *testing.T) {

assertEqual(t, retryCount+1, resp.Request.Attempt)
assertEqual(t, 3, hookCalledCount)

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

var errSeekFailure = fmt.Errorf("failing seek test")
Expand Down