Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(redirect)!: update error type on no redirect policy and cleanup #893 #900

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,23 +1131,18 @@ func (c *Client) newErrorInterface() any {
// SetRedirectPolicy method sets the redirect policy for the client. Resty provides ready-to-use
// redirect policies. Wanna create one for yourself, refer to `redirect.go`.
//
// client.SetRedirectPolicy(FlexibleRedirectPolicy(20))
// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20))
//
// // Need multiple redirect policies together
// client.SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("host1.com", "host2.net"))
func (c *Client) SetRedirectPolicy(policies ...any) *Client {
for _, p := range policies {
if _, ok := p.(RedirectPolicy); !ok {
c.log.Errorf("%v does not implement resty.RedirectPolicy (missing Apply method)",
functionName(p))
}
}

// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20), resty.DomainCheckRedirectPolicy("host1.com", "host2.net"))
//
// NOTE: It overwrites the previous redirect policies in the client instance.
func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
for _, p := range policies {
if err := p.(RedirectPolicy).Apply(req, via); err != nil {
if err := p.Apply(req, via); err != nil {
return err
}
}
Expand Down
18 changes: 10 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,21 @@ func TestClientRedirectPolicy(t *testing.T) {
ts := createRedirectServer(t)
defer ts.Close()

c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20))
_, err := c.R().Get(ts.URL + "/redirect-1")
c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("127.0.0.1"))
_, err := c.R().
SetHeader("Name1", "Value1").
SetHeader("Name2", "Value2").
SetHeader("Name3", "Value3").
Get(ts.URL + "/redirect-1")

assertEqual(t, true, (err.Error() == "Get /redirect-21: stopped after 20 redirects" ||
err.Error() == "Get \"/redirect-21\": stopped after 20 redirects"))

c.SetRedirectPolicy(NoRedirectPolicy())
_, err = c.R().Get(ts.URL + "/redirect-1")
assertEqual(t, true, (err.Error() == "Get /redirect-2: resty: auto redirect is disabled" ||
err.Error() == "Get \"/redirect-2\": resty: auto redirect is disabled"))
res, err := c.R().Get(ts.URL + "/redirect-1")
assertNil(t, err)
assertEqual(t, http.StatusTemporaryRedirect, res.StatusCode())
assertEqual(t, `<a href="/redirect-2">Temporary Redirect</a>.`, res.String())
}

func TestClientTimeout(t *testing.T) {
Expand Down Expand Up @@ -485,9 +490,6 @@ func TestClientSettingsCoverage(t *testing.T) {
c.SetAuthToken(authToken)
assertEqual(t, authToken, c.AuthToken())

type brokenRedirectPolicy struct{}
c.SetRedirectPolicy(&brokenRedirectPolicy{})

c.SetCloseConnection(true)

c.DisableDebug()
Expand Down
19 changes: 8 additions & 11 deletions redirect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2015-2024 Jeevanandam M ([email protected]), All rights reserved.
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

Expand All @@ -12,8 +13,6 @@ import (
"strings"
)

var ErrAutoRedirectDisabled = errors.New("resty: auto redirect is disabled")

type (
// RedirectPolicy to regulate the redirects in the Resty client.
// Objects implementing the [RedirectPolicy] interface can be registered as
Expand All @@ -35,12 +34,12 @@ func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error
return f(req, via)
}

// NoRedirectPolicy is used to disable redirects in the Resty client
// NoRedirectPolicy is used to disable the redirects in the Resty client
//
// resty.SetRedirectPolicy(NoRedirectPolicy())
// resty.SetRedirectPolicy(resty.NoRedirectPolicy())
func NoRedirectPolicy() RedirectPolicy {
return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
return ErrAutoRedirectDisabled
return http.ErrUseLastResponse
})
}

Expand All @@ -60,22 +59,20 @@ func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
// DomainCheckRedirectPolicy method is convenient for defining domain name redirect rules in Resty clients.
// Redirect is allowed only for the host mentioned in the policy.
//
// resty.SetRedirectPolicy(DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
// resty.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy {
hosts := make(map[string]bool)
for _, h := range hostnames {
hosts[strings.ToLower(h)] = true
}

fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
if ok := hosts[getHostname(req.URL.Host)]; !ok {
return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy")
}

checkHostAndAddHeaders(req, via[0])
return nil
})

return fn
}

func getHostname(host string) (hostname string) {
Expand Down