diff --git a/client.go b/client.go index 94ed14f0..f42f36a9 100644 --- a/client.go +++ b/client.go @@ -1131,23 +1131,18 @@ func (c *Client) newErrorInterface() any { // SetRedirectPolicy method sets the redirect policy for the client. Resty provides ready-to-use // redirect policies. Wanna create one for yourself, refer to `redirect.go`. // -// client.SetRedirectPolicy(FlexibleRedirectPolicy(20)) +// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20)) // // // Need multiple redirect policies together -// client.SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("host1.com", "host2.net")) -func (c *Client) SetRedirectPolicy(policies ...any) *Client { - for _, p := range policies { - if _, ok := p.(RedirectPolicy); !ok { - c.log.Errorf("%v does not implement resty.RedirectPolicy (missing Apply method)", - functionName(p)) - } - } - +// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20), resty.DomainCheckRedirectPolicy("host1.com", "host2.net")) +// +// NOTE: It overwrites the previous redirect policies in the client instance. +func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { c.lock.Lock() defer c.lock.Unlock() c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { for _, p := range policies { - if err := p.(RedirectPolicy).Apply(req, via); err != nil { + if err := p.Apply(req, via); err != nil { return err } } diff --git a/client_test.go b/client_test.go index e66b84bd..86ee49e1 100644 --- a/client_test.go +++ b/client_test.go @@ -188,16 +188,21 @@ func TestClientRedirectPolicy(t *testing.T) { ts := createRedirectServer(t) defer ts.Close() - c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20)) - _, err := c.R().Get(ts.URL + "/redirect-1") + c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("127.0.0.1")) + _, err := c.R(). + SetHeader("Name1", "Value1"). + SetHeader("Name2", "Value2"). + SetHeader("Name3", "Value3"). + Get(ts.URL + "/redirect-1") assertEqual(t, true, (err.Error() == "Get /redirect-21: stopped after 20 redirects" || err.Error() == "Get \"/redirect-21\": stopped after 20 redirects")) c.SetRedirectPolicy(NoRedirectPolicy()) - _, err = c.R().Get(ts.URL + "/redirect-1") - assertEqual(t, true, (err.Error() == "Get /redirect-2: resty: auto redirect is disabled" || - err.Error() == "Get \"/redirect-2\": resty: auto redirect is disabled")) + res, err := c.R().Get(ts.URL + "/redirect-1") + assertNil(t, err) + assertEqual(t, http.StatusTemporaryRedirect, res.StatusCode()) + assertEqual(t, `Temporary Redirect.`, res.String()) } func TestClientTimeout(t *testing.T) { @@ -485,9 +490,6 @@ func TestClientSettingsCoverage(t *testing.T) { c.SetAuthToken(authToken) assertEqual(t, authToken, c.AuthToken()) - type brokenRedirectPolicy struct{} - c.SetRedirectPolicy(&brokenRedirectPolicy{}) - c.SetCloseConnection(true) c.DisableDebug() diff --git a/redirect.go b/redirect.go index fb37f38d..5c95d101 100644 --- a/redirect.go +++ b/redirect.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -12,8 +13,6 @@ import ( "strings" ) -var ErrAutoRedirectDisabled = errors.New("resty: auto redirect is disabled") - type ( // RedirectPolicy to regulate the redirects in the Resty client. // Objects implementing the [RedirectPolicy] interface can be registered as @@ -35,12 +34,12 @@ func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error return f(req, via) } -// NoRedirectPolicy is used to disable redirects in the Resty client +// NoRedirectPolicy is used to disable the redirects in the Resty client // -// resty.SetRedirectPolicy(NoRedirectPolicy()) +// resty.SetRedirectPolicy(resty.NoRedirectPolicy()) func NoRedirectPolicy() RedirectPolicy { return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { - return ErrAutoRedirectDisabled + return http.ErrUseLastResponse }) } @@ -60,22 +59,20 @@ func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy { // DomainCheckRedirectPolicy method is convenient for defining domain name redirect rules in Resty clients. // Redirect is allowed only for the host mentioned in the policy. // -// resty.SetRedirectPolicy(DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net")) +// resty.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net")) func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy { hosts := make(map[string]bool) for _, h := range hostnames { hosts[strings.ToLower(h)] = true } - fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { if ok := hosts[getHostname(req.URL.Host)]; !ok { return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy") } - + checkHostAndAddHeaders(req, via[0]) return nil }) - - return fn } func getHostname(host string) (hostname string) {