Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added callback-based request ID checking #581

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/crewjam/saml

go 1.19
go 1.23

require (
github.com/beevik/etree v1.2.0
Expand Down
14 changes: 14 additions & 0 deletions requestid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package saml

type RequestIdCheckFunction func(string) bool

func createDefaultChecker(possibleRequestIDs []string) RequestIdCheckFunction {
return func(id string) bool {
for _, possibleRequestID := range possibleRequestIDs {
if id == possibleRequestID {
return true
}
}
return false
}
}
104 changes: 72 additions & 32 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,17 @@ func (e ErrBadStatus) Error() string {
// ParseResponse extracts the SAML IDP response received in req, resolves
// artifacts when necessary, validates it, and returns the verified assertion.
func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
return sp.ParseResponse2(req, createDefaultChecker(possibleRequestIDs))
}

func (sp *ServiceProvider) ParseResponse2(req *http.Request, checkId RequestIdCheckFunction) (*Assertion, error) {
if artifactID := req.Form.Get("SAMLart"); artifactID != "" {
return sp.handleArtifactRequest(req.Context(), artifactID, possibleRequestIDs)
return sp.handleArtifactRequest(req.Context(), artifactID, checkId)
}
return sp.parseResponseHTTP(req, possibleRequestIDs)
return sp.parseResponseHTTP(req, checkId)
}

func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, possibleRequestIDs []string) (*Assertion, error) {
func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, checkRequestId RequestIdCheckFunction) (*Assertion, error) {
retErr := &InvalidResponseError{Now: TimeNow()}

artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID)
Expand Down Expand Up @@ -652,14 +656,14 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID
retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err)
return nil, retErr
}
assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID)
assertion, err := sp.ParseXMLArtifactResponse2(responseBody, checkRequestId, artifactResolveRequest.ID)
if err != nil {
return nil, err
}
return assertion, nil
}

func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, idCheck RequestIdCheckFunction) (*Assertion, error) {
retErr := &InvalidResponseError{
Now: TimeNow(),
}
Expand All @@ -670,24 +674,31 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI
return nil, retErr
}

assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs)
assertion, err := sp.ParseXMLResponse2(rawResponseBuf, idCheck)
if err != nil {
return nil, err
}
return assertion, nil
}

func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) {
return sp.ParseXMLArtifactResponse2(soapResponseXML,
createDefaultChecker(possibleRequestIDs),
artifactRequestID)
}

// ParseXMLArtifactResponse validates the SAML Artifact resolver response
// and returns the verified assertion.
//

// This function handles verifying the digital signature, and verifying
// that the specified conditions and properties are met.
//
// If the function fails it will return an InvalidResponseError whose
// properties are useful in describing which part of the parsing process
// failed. However, to discourage inadvertent disclosure the diagnostic
// information, the Error() method returns a static string.
func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) {
func (sp *ServiceProvider) ParseXMLArtifactResponse2(soapResponseXML []byte, checkFunction RequestIdCheckFunction, artifactRequestID string) (*Assertion, error) {
now := TimeNow()
retErr := &InvalidResponseError{
Response: string(soapResponseXML),
Expand Down Expand Up @@ -727,10 +738,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss
return nil, retErr
}

return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now)
return sp.parseArtifactResponse(artifactResponseEl, checkFunction, artifactRequestID, now)
}

func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) {
func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, checkFunction RequestIdCheckFunction, artifactRequestID string, now time.Time) (*Assertion, error) {
retErr := &InvalidResponseError{
Now: now,
Response: elementToString(artifactResponseEl),
Expand Down Expand Up @@ -778,7 +789,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme
return nil, retErr
}

assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseResponse(responseEl, checkFunction, now, signatureRequirement)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
Expand All @@ -799,6 +810,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme
// failed. However, to discourage inadvertent disclosure the diagnostic
// information, the Error() method returns a static string.
func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) {

now := TimeNow()
var err error
retErr := &InvalidResponseError{
Expand All @@ -822,7 +834,40 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
return nil, retErr
}

assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired)
assertion, err := sp.parseResponse(doc.Root(), createDefaultChecker(possibleRequestIDs), now, signatureRequired)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
}

return assertion, nil

}
func (sp *ServiceProvider) ParseXMLResponse2(decodedResponseXML []byte, checkFunction RequestIdCheckFunction) (*Assertion, error) {
now := TimeNow()
var err error
retErr := &InvalidResponseError{
Now: now,
Response: string(decodedResponseXML),
}

// ensure that the response XML is well-formed before we parse it
if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil {
retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
return nil, retErr
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes(decodedResponseXML); err != nil {
retErr.PrivateErr = err
return nil, retErr
}
if doc.Root() == nil {
retErr.PrivateErr = errors.New("invalid xml: no root")
return nil, retErr
}

assertion, err := sp.parseResponse(doc.Root(), checkFunction, now, signatureRequired)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
Expand All @@ -844,7 +889,7 @@ const (
// This function handles decrypting the message, verifying the digital
// signature on the assertion, and verifying that the specified conditions
// and properties are met.
func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
var responseSignatureErr error
var responseHasSignature bool
if signatureRequirement == signatureRequired {
Expand Down Expand Up @@ -876,14 +921,10 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
if sp.AllowIDPInitiated {
requestIDvalid = true
} else {
for _, possibleRequestID := range possibleRequestIDs {
if response.InResponseTo == possibleRequestID {
requestIDvalid = true
}
}
requestIDvalid = checkFunction(response.InResponseTo)
}
if !requestIDvalid {
return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs")
}

if response.IssueInstant.Add(MaxIssueDelay).Before(now) {
Expand Down Expand Up @@ -920,7 +961,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return nil, err
}
for _, encryptedAssertionEl := range encryptedAssertionEls {
assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, checkFunction, now, signatureRequirement)
if err != nil {
errs = append(errs, err)
continue
Expand All @@ -936,7 +977,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return nil, err
}
for _, assertionEl := range assertionEls {
assertion, err := sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement)
if err != nil {
errs = append(errs, err)
continue
Expand All @@ -959,12 +1000,12 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return &assertions[0], nil
}

func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
assertionEl, err := sp.decryptElement(encryptedAssertionEl)
if err != nil {
return nil, fmt.Errorf("failed to decrypt EncryptedAssertion: %v", err)
}
return sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
return sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement)
}

func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.Element, error) {
Expand Down Expand Up @@ -999,7 +1040,7 @@ func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.El
return doc.Root(), nil
}

func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
if signatureRequirement == signatureRequired {
sigErr := sp.validateSignature(assertionEl)
if sigErr != nil {
Expand All @@ -1013,7 +1054,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe
return nil, err
}

if err := sp.validateAssertion(&assertion, possibleRequestIDs, now); err != nil {
if err := sp.validateAssertion2(&assertion, checkFunction, now); err != nil {
return nil, err
}

Expand All @@ -1024,7 +1065,11 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe
// the requirements to accept. If validation fails, it returns an error describing
// the failure. (The digital signature on the assertion is not checked -- this
// should be done before calling this function).
func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error {
func (sp *ServiceProvider) validateAssertion(assertion *Assertion, allowedRequestIds []string, now time.Time) error {
return sp.validateAssertion2(assertion, createDefaultChecker(allowedRequestIds), now)
}

func (sp *ServiceProvider) validateAssertion2(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error {
if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) {
return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay))
}
Expand Down Expand Up @@ -1052,14 +1097,9 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque
// Finally, it is unclear that there is significant security value in checking InResponseTo when we allow
// IDP initiated assertions.
if !sp.AllowIDPInitiated {
for _, possibleRequestID := range possibleRequestIDs {
if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID {
requestIDvalid = true
break
}
}
requestIDvalid = checkFunction(subjectConfirmation.SubjectConfirmationData.InResponseTo)
if !requestIDvalid {
return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs)
return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs")
}
}
if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() {
Expand Down
34 changes: 34 additions & 0 deletions xmlenc/pubkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, erro
// the block cipher used is AES-256 CBC and the digest method is SHA-256. You can
// specify other ciphers and digest methods by assigning to BlockCipher or
// DigestMethod.
//
// OAEP implements the older RSA-OAEP (2001 spec) for backward compatibility
func OAEP() RSA {
return RSA{
BlockCipher: AES256CBC,
Expand All @@ -139,6 +141,36 @@ func OAEP() RSA {
}
}

func OAEP_2009_256() RSA {
return RSA{
BlockCipher: AES256CBC,
DigestMethod: SHA256,
algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep",

keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
},
keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
},
}
}

func OAEP_2009_512() RSA {
return RSA{
BlockCipher: AES256CBC,
DigestMethod: SHA512,
algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep",

keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
},
keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
},
}
}

// PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default
// the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15
// does not use a digest function.
Expand All @@ -158,5 +190,7 @@ func PKCS1v15() RSA {

func init() {
RegisterDecrypter(OAEP())
RegisterDecrypter(OAEP_2009_256())
RegisterDecrypter(OAEP_2009_512())
RegisterDecrypter(PKCS1v15())
}
Loading