Skip to content

Commit

Permalink
feat: use context timeout instead of http client timeout option #545 (#…
Browse files Browse the repository at this point in the history
…922)

- add Request.SetTimeout method
  • Loading branch information
jeevatkm authored Nov 26, 2024
1 parent 8b24a96 commit 48a1a59
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 27 deletions.
21 changes: 17 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ type Client struct {
disableWarn bool
allowMethodGetPayload bool
allowMethodDeletePayload bool
timeout time.Duration
retryCount int
retryWaitTime time.Duration
retryMaxWaitTime time.Duration
Expand Down Expand Up @@ -612,6 +613,7 @@ func (c *Client) R() *Request {
Cookies: make([]*http.Cookie, 0),
PathParams: make(map[string]string),
RawPathParams: make(map[string]string),
Timeout: c.timeout,
Debug: c.debug,
IsTrace: c.isTrace,
AuthScheme: c.authScheme,
Expand Down Expand Up @@ -1122,13 +1124,24 @@ func (c *Client) SetContentLength(l bool) *Client {
return c
}

// SetTimeout method sets the timeout for a request raised by the client.
// Timeout method returns the timeout duration value from the client
func (c *Client) Timeout() time.Duration {
c.lock.RLock()
defer c.lock.RUnlock()
return c.timeout
}

// SetTimeout method is used to set a timeout for a request raised by the client.
//
// client.SetTimeout(1 * time.Minute)
//
// It can be overridden at the request level. See [Request.SetTimeout]
//
// client.SetTimeout(time.Duration(1 * time.Minute))
// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout]
func (c *Client) SetTimeout(timeout time.Duration) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.httpClient.Timeout = timeout
c.timeout = timeout
return c
}

Expand Down Expand Up @@ -2077,7 +2090,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
requestDebugLogger(c, req)

req.Time = time.Now()
resp, err := c.Client().Do(req.RawRequest)
resp, err := c.Client().Do(req.withTimeout())

response := &Response{Request: req, RawResponse: resp}
response.setReceivedAt()
Expand Down
5 changes: 2 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ func TestClientTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

c := dcnl().SetTimeout(time.Millisecond * 200)
c := dcnl().SetTimeout(200 * time.Millisecond)
_, err := c.R().Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, strings.Contains(err.Error(), "Client.Timeout"))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientTimeoutWithinThreshold(t *testing.T) {
Expand Down
4 changes: 1 addition & 3 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"errors"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -212,8 +211,7 @@ func TestClientRetryWithSetContext(t *testing.T) {

assertNotNil(t, ts)
assertNotNil(t, err)
assertEqual(t, true, (strings.HasPrefix(err.Error(), "Get "+ts.URL+"/") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/\"")))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestRequestContext(t *testing.T) {
Expand Down
37 changes: 35 additions & 2 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Request struct {
AllowMethodGetPayload bool
AllowMethodDeletePayload bool
IsDone bool
Timeout time.Duration
RetryCount int
RetryWaitTime time.Duration
RetryMaxWaitTime time.Duration
Expand All @@ -84,6 +85,7 @@ type Request struct {
isSaveResponse bool
jsonEscapeHTML bool
ctx context.Context
ctxCancelFunc context.CancelFunc
values map[string]any
client *Client
bodyBuf *bytes.Buffer
Expand Down Expand Up @@ -894,6 +896,18 @@ func (r *Request) SetCookies(rs []*http.Cookie) *Request {
return r
}

// SetTimeout method is used to set a timeout for the current request
//
// client.R().SetTimeout(1 * time.Minute)
//
// It overrides the timeout set at the client instance level, See [Client.SetTimeout]
//
// NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout]
func (r *Request) SetTimeout(timeout time.Duration) *Request {
r.Timeout = timeout
return r
}

// SetLogger method sets given writer for logging Resty request and response details.
// By default, requests and responses inherit their logger from the client.
//
Expand Down Expand Up @@ -1273,8 +1287,14 @@ func (r *Request) Execute(method, url string) (res *Response, err error) {
break
}
if r.Context().Err() != nil {
err = wrapErrors(r.Context().Err(), err)
break
if r.ctxCancelFunc != nil {
r.ctxCancelFunc()
r.ctxCancelFunc = nil
}
if !errors.Is(err, context.DeadlineExceeded) {
err = wrapErrors(r.Context().Err(), err)
break
}
}
}

Expand Down Expand Up @@ -1425,6 +1445,7 @@ func (r *Request) Clone(ctx context.Context) *Request {
rr.initTraceIfEnabled()
r.values = make(map[string]any)
r.multipartErrChan = nil
r.ctxCancelFunc = nil

// copy bodyBuf
if r.bodyBuf != nil {
Expand Down Expand Up @@ -1634,6 +1655,18 @@ func (r *Request) isIdempotent() bool {
return found || r.AllowNonIdempotentRetry
}

func (r *Request) withTimeout() *http.Request {
if _, found := r.Context().Deadline(); found {
return r.RawRequest
}
if r.Timeout > 0 {
ctx, ctxCancelFunc := context.WithTimeout(r.Context(), r.Timeout)
r.ctxCancelFunc = ctxCancelFunc
return r.RawRequest.WithContext(ctx)
}
return r.RawRequest
}

func jsonIndent(v []byte) []byte {
buf := acquireBuffer()
defer releaseBuffer(buf)
Expand Down
41 changes: 41 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,47 @@ func TestRequestNoRetryOnNonIdempotentMethod(t *testing.T) {
assertEqual(t, 500, resp.StatusCode())
}

func TestRequestContextTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

t.Run("use client set timeout", func(t *testing.T) {
c := dcnl().SetTimeout(200 * time.Millisecond)
assertEqual(t, true, c.Timeout() > 0)

req := c.R()
assertEqual(t, true, req.Timeout > 0)

_, err := req.Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

t.Run("use request set timeout", func(t *testing.T) {
c := dcnl()
assertEqual(t, true, c.Timeout() == 0)

_, err := c.R().
SetTimeout(200 * time.Millisecond).
Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

t.Run("use external context for timeout", func(t *testing.T) {
ctx, ctxCancelFunc := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer ctxCancelFunc()

c := dcnl()
_, err := c.R().
SetContext(ctx).
Get(ts.URL + "/set-timeout-test")

assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
})

}

func TestRequestPanicContext(t *testing.T) {
defer func() {
if r := recover(); r == nil {
Expand Down
23 changes: 8 additions & 15 deletions retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ func TestConditionalGetRequestLevel(t *testing.T) {
logResponse(t, resp)
}

func TestClientRetryGet(t *testing.T) {
func TestClientRetryGetWithTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

c := dcnl().
SetTimeout(time.Millisecond * 50).
SetTimeout(50 * time.Millisecond).
SetRetryCount(3)

resp, err := c.R().Get(ts.URL + "/set-retrycount-test")
Expand All @@ -99,9 +99,7 @@ func TestClientRetryGet(t *testing.T) {
assertEqual(t, 0, resp.StatusCode())
assertEqual(t, 0, len(resp.Cookies()))
assertEqual(t, 0, len(resp.Header()))

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) {
Expand Down Expand Up @@ -561,7 +559,7 @@ func TestClientRetryCountWithTimeout(t *testing.T) {
attempt := 0

c := dcnl().
SetTimeout(time.Millisecond * 50).
SetTimeout(50 * time.Millisecond).
SetRetryCount(1).
AddRetryCondition(
func(r *Response, _ error) bool {
Expand All @@ -576,12 +574,8 @@ func TestClientRetryCountWithTimeout(t *testing.T) {
assertEqual(t, 0, resp.StatusCode())
assertEqual(t, 0, len(resp.Cookies()))
assertEqual(t, 0, len(resp.Header()))

// 2 attempts were made
assertEqual(t, 2, resp.Request.Attempt)

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
Expand All @@ -596,6 +590,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
SetJSONEscapeHTML(false).
SetResult(AuthSuccess{}).
SetTimeout(10 * time.Millisecond).
Get(ts.URL + "/set-retry-error-recover")

assertError(t, err)
Expand All @@ -608,7 +603,7 @@ func TestClientRetryTooManyRequestsAndRecover(t *testing.T) {
assertNil(t, resp.Error())
}

func TestClientRetryHook(t *testing.T) {
func TestClientRetryHookWithTimeout(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

Expand Down Expand Up @@ -641,9 +636,7 @@ func TestClientRetryHook(t *testing.T) {

assertEqual(t, retryCount+1, resp.Request.Attempt)
assertEqual(t, 3, hookCalledCount)

assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\""))
assertEqual(t, true, errors.Is(err, context.DeadlineExceeded))
}

var errSeekFailure = fmt.Errorf("failing seek test")
Expand Down

0 comments on commit 48a1a59

Please sign in to comment.