Skip to content

Commit

Permalink
feat!: enhance credentials usage and update warning message for crede…
Browse files Browse the repository at this point in the history
…ntials requests (#915)

- User type become unexported 
- HTTP not secure warning message added for auth token flow too
- Intialize the default auth scheme value `Bearer` during client creation
  • Loading branch information
jeevatkm authored Nov 19, 2024
1 parent 0f4a16a commit 3b97332
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 84 deletions.
35 changes: 8 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ var (
jsonKey = "json"
xmlKey = "xml"

defaultAuthScheme = "Bearer"

hdrUserAgentValue = "go-resty/" + Version + " (https://github.com/go-resty/resty)"
bufPool = &sync.Pool{New: func() any { return &bytes.Buffer{} }}
)
Expand Down Expand Up @@ -172,7 +174,7 @@ type Client struct {
pathParams map[string]string
rawPathParams map[string]string
header http.Header
userInfo *User
credentials *credentials
authToken string
authScheme string
cookies []*http.Cookie
Expand Down Expand Up @@ -221,25 +223,13 @@ type Client struct {
certWatcherStopChan chan bool
}

// User type is to hold an username and password information
type User struct {
Username, Password string
}

// CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs.
type CertWatcherOptions struct {
// PoolInterval is the frequency at which resty will check if the PEM file needs to be reloaded.
// Default is 24 hours.
PoolInterval time.Duration
}

// Clone method returns deep copy of u.
func (u *User) Clone() *User {
uu := new(User)
*uu = *u
return uu
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Client methods
//___________________________________
Expand Down Expand Up @@ -498,15 +488,6 @@ func (c *Client) SetFormData(data map[string]string) *Client {
return c
}

// UserInfo method returns the authorization username and password.
//
// userInfo := client.UserInfo()
func (c *Client) UserInfo() *User {
c.lock.RLock()
defer c.lock.RUnlock()
return c.userInfo
}

// SetBasicAuth method sets the basic authentication header in the HTTP request. For Example:
//
// Authorization: Basic <base64-encoded-value>
Expand All @@ -522,7 +503,7 @@ func (c *Client) UserInfo() *User {
func (c *Client) SetBasicAuth(username, password string) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.userInfo = &User{Username: username, Password: password}
c.credentials = &credentials{Username: username, Password: password}
return c
}

Expand Down Expand Up @@ -611,8 +592,8 @@ func (c *Client) SetDigestAuth(username, password string) *Client {
c.lock.Unlock()
c.AddRequestMiddleware(func(c *Client, _ *Request) error {
c.httpClient.Transport = &digestTransport{
digestCredentials: digestCredentials{username, password},
transport: oldTransport,
credentials: credentials{username, password},
transport: oldTransport,
}
return nil
})
Expand All @@ -638,7 +619,6 @@ func (c *Client) R() *Request {
IsTrace: c.isTrace,
AuthScheme: c.authScheme,
AuthToken: c.authToken,
UserInfo: c.userInfo,
RetryCount: c.retryCount,
RetryWaitTime: c.retryWaitTime,
RetryMaxWaitTime: c.retryMaxWaitTime,
Expand All @@ -660,6 +640,7 @@ func (c *Client) R() *Request {
setContentLength: c.setContentLength,
generateCurlOnDebug: c.generateCurlOnDebug,
unescapeQueryParams: c.unescapeQueryParams,
credentials: c.credentials,
}

if c.ctx != nil {
Expand Down Expand Up @@ -2012,7 +1993,7 @@ func (c *Client) Clone(ctx context.Context) *Client {
cc.header = c.header.Clone()
cc.pathParams = maps.Clone(c.pathParams)
cc.rawPathParams = maps.Clone(c.rawPathParams)
cc.userInfo = c.userInfo.Clone()
cc.credentials = c.credentials.Clone()
cc.contentTypeEncoders = maps.Clone(c.contentTypeEncoders)
cc.contentTypeDecoders = maps.Clone(c.contentTypeDecoders)
cc.contentDecompressors = maps.Clone(c.contentDecompressors)
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1381,8 +1381,8 @@ func TestClientClone(t *testing.T) {
// assert non-interface type
assertEqual(t, "http://localhost", parent.BaseURL())
assertEqual(t, "https://local.host", clone.BaseURL())
assertEqual(t, "parent", parent.UserInfo().Username)
assertEqual(t, "clone", clone.UserInfo().Username)
assertEqual(t, "parent", parent.credentials.Username)
assertEqual(t, "clone", clone.credentials.Username)

// assert interface/pointer type
assertEqual(t, parent.Client(), clone.Client())
Expand Down
30 changes: 13 additions & 17 deletions digest.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@ var hashFuncs = map[string]func() hash.Hash{
"SHA-512-256-sess": sha512.New,
}

type digestCredentials struct {
username, password string
}

type digestTransport struct {
digestCredentials
credentials
transport http.RoundTripper
}

Expand Down Expand Up @@ -98,9 +94,9 @@ func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error)
return dt.transport.RoundTrip(req2)
}

func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *credentials {
return &credentials{
username: dt.username,
func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *digestCredentials {
return &digestCredentials{
username: dt.Username,
userhash: c.userhash,
realm: c.realm,
nonce: c.nonce,
Expand All @@ -111,7 +107,7 @@ func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *cred
messageQop: c.qop,
nc: 0,
method: req.Method,
password: dt.password,
password: dt.Password,
}
}

Expand Down Expand Up @@ -203,7 +199,7 @@ func parseChallenge(input string) (*challenge, error) {
return c, nil
}

type credentials struct {
type digestCredentials struct {
username string
userhash string
realm string
Expand All @@ -219,7 +215,7 @@ type credentials struct {
password string
}

func (c *credentials) authorize() (string, error) {
func (c *digestCredentials) authorize() (string, error) {
if _, ok := hashFuncs[c.algorithm]; !ok {
return "", ErrDigestAlgNotSupported
}
Expand Down Expand Up @@ -257,7 +253,7 @@ func (c *credentials) authorize() (string, error) {
return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil
}

func (c *credentials) validateQop() error {
func (c *digestCredentials) validateQop() error {
// Currently only supporting auth quality of protection. TODO: add auth-int support
// NOTE: cURL support auth-int qop for requests other than POST and PUT (i.e. w/o body) by hashing an empty string
// is this applicable for resty? see: https://github.com/curl/curl/blob/307b7543ea1e73ab04e062bdbe4b5bb409eaba3a/lib/vauth/digest.c#L774
Expand All @@ -282,14 +278,14 @@ func (c *credentials) validateQop() error {
return nil
}

func (c *credentials) h(data string) string {
func (c *digestCredentials) h(data string) string {
hfCtor := hashFuncs[c.algorithm]
hf := hfCtor()
_, _ = hf.Write([]byte(data)) // Hash.Write never returns an error
return fmt.Sprintf("%x", hf.Sum(nil))
}

func (c *credentials) resp() (string, error) {
func (c *digestCredentials) resp() (string, error) {
c.nc++

b := make([]byte, 16)
Expand All @@ -306,12 +302,12 @@ func (c *credentials) resp() (string, error) {
c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil
}

func (c *credentials) kd(secret, data string) string {
func (c *digestCredentials) kd(secret, data string) string {
return c.h(fmt.Sprintf("%s:%s", secret, data))
}

// RFC 7616 3.4.2
func (c *credentials) ha1() string {
func (c *digestCredentials) ha1() string {
ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password))
if c.sessionAlg {
return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce))
Expand All @@ -321,7 +317,7 @@ func (c *credentials) ha1() string {
}

// RFC 7616 3.4.3
func (c *credentials) ha2() string {
func (c *digestCredentials) ha2() string {
// currently no auth-int support
return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI))
}
5 changes: 4 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,12 @@ func Example_post() {

printOutput(resp1, err1)

type User struct {
Username, Password string
}
// POST Struct, default is JSON content type. No need to set one
resp2, err2 := client.R().
SetBody(resty.User{Username: "testuser", Password: "testpass"}).
SetBody(User{Username: "testuser", Password: "testpass"}).
SetResult(&AuthSuccess{}). // or SetResult(AuthSuccess{}).
SetError(&AuthError{}). // or SetError(AuthError{}).
Post("https://myapp.com/login")
Expand Down
27 changes: 11 additions & 16 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,28 +302,23 @@ func createHTTPRequest(c *Client, r *Request) (err error) {
}

func addCredentials(c *Client, r *Request) error {
var isBasicAuth bool
credentialsAdded := false
// Basic Auth
if r.UserInfo != nil {
r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
isBasicAuth = true
}

if !c.IsDisableWarn() {
if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
}
if r.credentials != nil {
credentialsAdded = true
r.RawRequest.SetBasicAuth(r.credentials.Username, r.credentials.Password)
}

// Build the token Auth header
if !isStringEmpty(r.AuthToken) {
var authScheme string
if isStringEmpty(r.AuthScheme) {
authScheme = "Bearer"
} else {
authScheme = r.AuthScheme
credentialsAdded = true
r.RawRequest.Header.Set(c.HeaderAuthorizationKey(), r.AuthScheme+" "+r.AuthToken)
}

if !c.IsDisableWarn() && credentialsAdded {
if strings.HasPrefix(r.URL, "http") {
r.log.Warnf("Using sensitive credentials in HTTP mode is not secure. Use HTTPS")
}
r.RawRequest.Header.Set(c.HeaderAuthorizationKey(), authScheme+" "+r.AuthToken)
}

return nil
Expand Down
12 changes: 6 additions & 6 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ type Request struct {
Result any
Error any
RawRequest *http.Request
UserInfo *User
Cookies []*http.Cookie
Debug bool
CloseConnection bool
Expand Down Expand Up @@ -77,6 +76,7 @@ type Request struct {
// first attempt + retry count = total attempts
Attempt int

credentials *credentials
isMultiPart bool
isFormData bool
setContentLength bool
Expand Down Expand Up @@ -618,7 +618,7 @@ func (r *Request) SetContentLength(l bool) *Request {
//
// It overrides the credentials set by method [Client.SetBasicAuth].
func (r *Request) SetBasicAuth(username, password string) *Request {
r.UserInfo = &User{Username: username, Password: password}
r.credentials = &credentials{Username: username, Password: password}
return r
}

Expand Down Expand Up @@ -677,8 +677,8 @@ func (r *Request) SetDigestAuth(username, password string) *Request {
oldTransport := r.client.httpClient.Transport
r.client.AddRequestMiddleware(func(c *Client, _ *Request) error {
c.httpClient.Transport = &digestTransport{
digestCredentials: digestCredentials{username, password},
transport: oldTransport,
credentials: credentials{username, password},
transport: oldTransport,
}
return nil
})
Expand Down Expand Up @@ -1400,8 +1400,8 @@ func (r *Request) Clone(ctx context.Context) *Request {
rr.RawPathParams = maps.Clone(r.RawPathParams)

// clone basic auth
if r.UserInfo != nil {
rr.UserInfo = r.UserInfo.Clone()
if r.credentials != nil {
rr.credentials = r.credentials.Clone()
}

// clone cookies
Expand Down
Loading

0 comments on commit 3b97332

Please sign in to comment.