Skip to content

Commit

Permalink
Added a function-based request checker to allow users to specify thei…
Browse files Browse the repository at this point in the history
…r own request ID check functions. This allows you to check IDs more carefully than you could with just an array - for example to do things like add timeouts to the validity of the auth id.

For backward compatibility, I left the original functions in place.

Also added an additional XML type for OAEP as current shibboleth returns a different type string

old: http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p
new: http://www.w3.org/2009/xmlenc11#rsa-oaep

For backward compatibility, I left the original functions the same.
  • Loading branch information
wz2b committed Nov 29, 2024
1 parent bbccb79 commit e2cacd1
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 32 deletions.
14 changes: 14 additions & 0 deletions requestid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package saml

type RequestIdCheckFunction func(string) bool

func createDefaultChecker(possibleRequestIDs []string) RequestIdCheckFunction {
return func(id string) bool {
for _, possibleRequestID := range possibleRequestIDs {
if id == possibleRequestID {
return true
}
}
return false
}
}
100 changes: 68 additions & 32 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,17 @@ func (e ErrBadStatus) Error() string {
// ParseResponse extracts the SAML IDP response received in req, resolves
// artifacts when necessary, validates it, and returns the verified assertion.
func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
return sp.ParseResponse2(req, createDefaultChecker(possibleRequestIDs))
}

func (sp *ServiceProvider) ParseResponse2(req *http.Request, checkId RequestIdCheckFunction) (*Assertion, error) {
if artifactID := req.Form.Get("SAMLart"); artifactID != "" {
return sp.handleArtifactRequest(req.Context(), artifactID, possibleRequestIDs)
return sp.handleArtifactRequest(req.Context(), artifactID, checkId)
}
return sp.parseResponseHTTP(req, possibleRequestIDs)
return sp.parseResponseHTTP(req, checkId)
}

func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, possibleRequestIDs []string) (*Assertion, error) {
func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, checkRequestId RequestIdCheckFunction) (*Assertion, error) {
retErr := &InvalidResponseError{Now: TimeNow()}

artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID)
Expand Down Expand Up @@ -652,14 +656,14 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID
retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err)
return nil, retErr
}
assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID)
assertion, err := sp.ParseXMLArtifactResponse2(responseBody, checkRequestId, artifactResolveRequest.ID)
if err != nil {
return nil, err
}
return assertion, nil
}

func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, idCheck RequestIdCheckFunction) (*Assertion, error) {
retErr := &InvalidResponseError{
Now: TimeNow(),
}
Expand All @@ -670,24 +674,31 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI
return nil, retErr
}

assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs)
assertion, err := sp.ParseXMLResponse2(rawResponseBuf, idCheck)
if err != nil {
return nil, err
}
return assertion, nil
}

func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) {
return sp.ParseXMLArtifactResponse2(soapResponseXML,
createDefaultChecker(possibleRequestIDs),
artifactRequestID)
}

// ParseXMLArtifactResponse validates the SAML Artifact resolver response
// and returns the verified assertion.
//

// This function handles verifying the digital signature, and verifying
// that the specified conditions and properties are met.
//
// If the function fails it will return an InvalidResponseError whose
// properties are useful in describing which part of the parsing process
// failed. However, to discourage inadvertent disclosure the diagnostic
// information, the Error() method returns a static string.
func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) {
func (sp *ServiceProvider) ParseXMLArtifactResponse2(soapResponseXML []byte, checkFunction RequestIdCheckFunction, artifactRequestID string) (*Assertion, error) {
now := TimeNow()
retErr := &InvalidResponseError{
Response: string(soapResponseXML),
Expand Down Expand Up @@ -727,10 +738,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss
return nil, retErr
}

return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now)
return sp.parseArtifactResponse(artifactResponseEl, checkFunction, artifactRequestID, now)
}

func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) {
func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, checkFunction RequestIdCheckFunction, artifactRequestID string, now time.Time) (*Assertion, error) {
retErr := &InvalidResponseError{
Now: now,
Response: elementToString(artifactResponseEl),
Expand Down Expand Up @@ -778,7 +789,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme
return nil, retErr
}

assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseResponse(responseEl, checkFunction, now, signatureRequirement)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
Expand All @@ -799,6 +810,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme
// failed. However, to discourage inadvertent disclosure the diagnostic
// information, the Error() method returns a static string.
func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) {

now := TimeNow()
var err error
retErr := &InvalidResponseError{
Expand All @@ -822,7 +834,40 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
return nil, retErr
}

assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired)
assertion, err := sp.parseResponse(doc.Root(), createDefaultChecker(possibleRequestIDs), now, signatureRequired)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
}

return assertion, nil

}
func (sp *ServiceProvider) ParseXMLResponse2(decodedResponseXML []byte, checkFunction RequestIdCheckFunction) (*Assertion, error) {
now := TimeNow()
var err error
retErr := &InvalidResponseError{
Now: now,
Response: string(decodedResponseXML),
}

// ensure that the response XML is well-formed before we parse it
if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil {
retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
return nil, retErr
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes(decodedResponseXML); err != nil {
retErr.PrivateErr = err
return nil, retErr
}
if doc.Root() == nil {
retErr.PrivateErr = errors.New("invalid xml: no root")
return nil, retErr
}

assertion, err := sp.parseResponse(doc.Root(), checkFunction, now, signatureRequired)
if err != nil {
retErr.PrivateErr = err
return nil, retErr
Expand All @@ -844,7 +889,7 @@ const (
// This function handles decrypting the message, verifying the digital
// signature on the assertion, and verifying that the specified conditions
// and properties are met.
func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
var responseSignatureErr error
var responseHasSignature bool
if signatureRequirement == signatureRequired {
Expand Down Expand Up @@ -876,14 +921,10 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
if sp.AllowIDPInitiated {
requestIDvalid = true
} else {
for _, possibleRequestID := range possibleRequestIDs {
if response.InResponseTo == possibleRequestID {
requestIDvalid = true
}
}
requestIDvalid = checkFunction(response.InResponseTo)
}
if !requestIDvalid {
return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs")
}

if response.IssueInstant.Add(MaxIssueDelay).Before(now) {
Expand Down Expand Up @@ -920,7 +961,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return nil, err
}
for _, encryptedAssertionEl := range encryptedAssertionEls {
assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, checkFunction, now, signatureRequirement)
if err != nil {
errs = append(errs, err)
continue
Expand All @@ -936,7 +977,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return nil, err
}
for _, assertionEl := range assertionEls {
assertion, err := sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
assertion, err := sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement)
if err != nil {
errs = append(errs, err)
continue
Expand All @@ -959,12 +1000,12 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
return &assertions[0], nil
}

func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
assertionEl, err := sp.decryptElement(encryptedAssertionEl)
if err != nil {
return nil, fmt.Errorf("failed to decrypt EncryptedAssertion: %v", err)
}
return sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
return sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement)
}

func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.Element, error) {
Expand Down Expand Up @@ -999,7 +1040,7 @@ func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.El
return doc.Root(), nil
}

func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
if signatureRequirement == signatureRequired {
sigErr := sp.validateSignature(assertionEl)
if sigErr != nil {
Expand All @@ -1013,7 +1054,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe
return nil, err
}

if err := sp.validateAssertion(&assertion, possibleRequestIDs, now); err != nil {
if err := sp.validateAssertion(&assertion, checkFunction, now); err != nil {
return nil, err
}

Expand All @@ -1024,7 +1065,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe
// the requirements to accept. If validation fails, it returns an error describing
// the failure. (The digital signature on the assertion is not checked -- this
// should be done before calling this function).
func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error {
func (sp *ServiceProvider) validateAssertion(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error {
if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) {
return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay))
}
Expand Down Expand Up @@ -1052,14 +1093,9 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque
// Finally, it is unclear that there is significant security value in checking InResponseTo when we allow
// IDP initiated assertions.
if !sp.AllowIDPInitiated {
for _, possibleRequestID := range possibleRequestIDs {
if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID {
requestIDvalid = true
break
}
}
requestIDvalid = checkFunction(subjectConfirmation.SubjectConfirmationData.InResponseTo)
if !requestIDvalid {
return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs)
return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs")
}
}
if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() {
Expand Down
34 changes: 34 additions & 0 deletions xmlenc/pubkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, erro
// the block cipher used is AES-256 CBC and the digest method is SHA-256. You can
// specify other ciphers and digest methods by assigning to BlockCipher or
// DigestMethod.
//
// OAEP implements the older RSA-OAEP (2001 spec) for backward compatibility
func OAEP() RSA {
return RSA{
BlockCipher: AES256CBC,
Expand All @@ -139,6 +141,36 @@ func OAEP() RSA {
}
}

func OAEP_2009_256() RSA {
return RSA{
BlockCipher: AES256CBC,
DigestMethod: SHA256,
algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep",

keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
},
keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
},
}
}

func OAEP_2009_512() RSA {
return RSA{
BlockCipher: AES256CBC,
DigestMethod: SHA512,
algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep",

keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
},
keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
},
}
}

// PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default
// the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15
// does not use a digest function.
Expand All @@ -158,5 +190,7 @@ func PKCS1v15() RSA {

func init() {
RegisterDecrypter(OAEP())
RegisterDecrypter(OAEP_2009_256())
RegisterDecrypter(OAEP_2009_512())
RegisterDecrypter(PKCS1v15())
}

0 comments on commit e2cacd1

Please sign in to comment.