From e2cacd138446526e057b736203717b793f6f08a1 Mon Sep 17 00:00:00 2001 From: Christopher Piggott Date: Fri, 29 Nov 2024 20:34:15 +0000 Subject: [PATCH 1/4] Added a function-based request checker to allow users to specify their own request ID check functions. This allows you to check IDs more carefully than you could with just an array - for example to do things like add timeouts to the validity of the auth id. For backward compatibility, I left the original functions in place. Also added an additional XML type for OAEP as current shibboleth returns a different type string old: http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p new: http://www.w3.org/2009/xmlenc11#rsa-oaep For backward compatibility, I left the original functions the same. --- requestid.go | 14 +++++++ service_provider.go | 100 ++++++++++++++++++++++++++++++-------------- xmlenc/pubkey.go | 34 +++++++++++++++ 3 files changed, 116 insertions(+), 32 deletions(-) create mode 100644 requestid.go diff --git a/requestid.go b/requestid.go new file mode 100644 index 00000000..bd3ab1b9 --- /dev/null +++ b/requestid.go @@ -0,0 +1,14 @@ +package saml + +type RequestIdCheckFunction func(string) bool + +func createDefaultChecker(possibleRequestIDs []string) RequestIdCheckFunction { + return func(id string) bool { + for _, possibleRequestID := range possibleRequestIDs { + if id == possibleRequestID { + return true + } + } + return false + } +} diff --git a/service_provider.go b/service_provider.go index 30b35670..b6d87827 100644 --- a/service_provider.go +++ b/service_provider.go @@ -600,13 +600,17 @@ func (e ErrBadStatus) Error() string { // ParseResponse extracts the SAML IDP response received in req, resolves // artifacts when necessary, validates it, and returns the verified assertion. func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) { + return sp.ParseResponse2(req, createDefaultChecker(possibleRequestIDs)) +} + +func (sp *ServiceProvider) ParseResponse2(req *http.Request, checkId RequestIdCheckFunction) (*Assertion, error) { if artifactID := req.Form.Get("SAMLart"); artifactID != "" { - return sp.handleArtifactRequest(req.Context(), artifactID, possibleRequestIDs) + return sp.handleArtifactRequest(req.Context(), artifactID, checkId) } - return sp.parseResponseHTTP(req, possibleRequestIDs) + return sp.parseResponseHTTP(req, checkId) } -func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, checkRequestId RequestIdCheckFunction) (*Assertion, error) { retErr := &InvalidResponseError{Now: TimeNow()} artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID) @@ -652,14 +656,14 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err) return nil, retErr } - assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID) + assertion, err := sp.ParseXMLArtifactResponse2(responseBody, checkRequestId, artifactResolveRequest.ID) if err != nil { return nil, err } return assertion, nil } -func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, idCheck RequestIdCheckFunction) (*Assertion, error) { retErr := &InvalidResponseError{ Now: TimeNow(), } @@ -670,16 +674,23 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI return nil, retErr } - assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs) + assertion, err := sp.ParseXMLResponse2(rawResponseBuf, idCheck) if err != nil { return nil, err } return assertion, nil } +func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { + return sp.ParseXMLArtifactResponse2(soapResponseXML, + createDefaultChecker(possibleRequestIDs), + artifactRequestID) +} + // ParseXMLArtifactResponse validates the SAML Artifact resolver response // and returns the verified assertion. // + // This function handles verifying the digital signature, and verifying // that the specified conditions and properties are met. // @@ -687,7 +698,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. -func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLArtifactResponse2(soapResponseXML []byte, checkFunction RequestIdCheckFunction, artifactRequestID string) (*Assertion, error) { now := TimeNow() retErr := &InvalidResponseError{ Response: string(soapResponseXML), @@ -727,10 +738,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss return nil, retErr } - return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now) + return sp.parseArtifactResponse(artifactResponseEl, checkFunction, artifactRequestID, now) } -func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) { +func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, checkFunction RequestIdCheckFunction, artifactRequestID string, now time.Time) (*Assertion, error) { retErr := &InvalidResponseError{ Now: now, Response: elementToString(artifactResponseEl), @@ -778,7 +789,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme return nil, retErr } - assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseResponse(responseEl, checkFunction, now, signatureRequirement) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -799,6 +810,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) { + now := TimeNow() var err error retErr := &InvalidResponseError{ @@ -822,7 +834,40 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR return nil, retErr } - assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired) + assertion, err := sp.parseResponse(doc.Root(), createDefaultChecker(possibleRequestIDs), now, signatureRequired) + if err != nil { + retErr.PrivateErr = err + return nil, retErr + } + + return assertion, nil + +} +func (sp *ServiceProvider) ParseXMLResponse2(decodedResponseXML []byte, checkFunction RequestIdCheckFunction) (*Assertion, error) { + now := TimeNow() + var err error + retErr := &InvalidResponseError{ + Now: now, + Response: string(decodedResponseXML), + } + + // ensure that the response XML is well-formed before we parse it + if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil { + retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err) + return nil, retErr + } + + doc := etree.NewDocument() + if err := doc.ReadFromBytes(decodedResponseXML); err != nil { + retErr.PrivateErr = err + return nil, retErr + } + if doc.Root() == nil { + retErr.PrivateErr = errors.New("invalid xml: no root") + return nil, retErr + } + + assertion, err := sp.parseResponse(doc.Root(), checkFunction, now, signatureRequired) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -844,7 +889,7 @@ const ( // This function handles decrypting the message, verifying the digital // signature on the assertion, and verifying that the specified conditions // and properties are met. -func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { var responseSignatureErr error var responseHasSignature bool if signatureRequirement == signatureRequired { @@ -876,14 +921,10 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ if sp.AllowIDPInitiated { requestIDvalid = true } else { - for _, possibleRequestID := range possibleRequestIDs { - if response.InResponseTo == possibleRequestID { - requestIDvalid = true - } - } + requestIDvalid = checkFunction(response.InResponseTo) } if !requestIDvalid { - return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs) + return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs") } if response.IssueInstant.Add(MaxIssueDelay).Before(now) { @@ -920,7 +961,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return nil, err } for _, encryptedAssertionEl := range encryptedAssertionEls { - assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, checkFunction, now, signatureRequirement) if err != nil { errs = append(errs, err) continue @@ -936,7 +977,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return nil, err } for _, assertionEl := range assertionEls { - assertion, err := sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement) if err != nil { errs = append(errs, err) continue @@ -959,12 +1000,12 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return &assertions[0], nil } -func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { assertionEl, err := sp.decryptElement(encryptedAssertionEl) if err != nil { return nil, fmt.Errorf("failed to decrypt EncryptedAssertion: %v", err) } - return sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement) + return sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement) } func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.Element, error) { @@ -999,7 +1040,7 @@ func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.El return doc.Root(), nil } -func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { if signatureRequirement == signatureRequired { sigErr := sp.validateSignature(assertionEl) if sigErr != nil { @@ -1013,7 +1054,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe return nil, err } - if err := sp.validateAssertion(&assertion, possibleRequestIDs, now); err != nil { + if err := sp.validateAssertion(&assertion, checkFunction, now); err != nil { return nil, err } @@ -1024,7 +1065,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe // the requirements to accept. If validation fails, it returns an error describing // the failure. (The digital signature on the assertion is not checked -- this // should be done before calling this function). -func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error { +func (sp *ServiceProvider) validateAssertion(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error { if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) { return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay)) } @@ -1052,14 +1093,9 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque // Finally, it is unclear that there is significant security value in checking InResponseTo when we allow // IDP initiated assertions. if !sp.AllowIDPInitiated { - for _, possibleRequestID := range possibleRequestIDs { - if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID { - requestIDvalid = true - break - } - } + requestIDvalid = checkFunction(subjectConfirmation.SubjectConfirmationData.InResponseTo) if !requestIDvalid { - return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs) + return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs") } } if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() { diff --git a/xmlenc/pubkey.go b/xmlenc/pubkey.go index 13d4d9e7..e1daee93 100644 --- a/xmlenc/pubkey.go +++ b/xmlenc/pubkey.go @@ -125,6 +125,8 @@ func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, erro // the block cipher used is AES-256 CBC and the digest method is SHA-256. You can // specify other ciphers and digest methods by assigning to BlockCipher or // DigestMethod. +// +// OAEP implements the older RSA-OAEP (2001 spec) for backward compatibility func OAEP() RSA { return RSA{ BlockCipher: AES256CBC, @@ -139,6 +141,36 @@ func OAEP() RSA { } } +func OAEP_2009_256() RSA { + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA256, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + +func OAEP_2009_512() RSA { + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA512, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + // PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default // the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15 // does not use a digest function. @@ -158,5 +190,7 @@ func PKCS1v15() RSA { func init() { RegisterDecrypter(OAEP()) + RegisterDecrypter(OAEP_2009_256()) + RegisterDecrypter(OAEP_2009_512()) RegisterDecrypter(PKCS1v15()) } From 8924fc8332bd3e6d08311a1a028670c620c4a6a5 Mon Sep 17 00:00:00 2001 From: Christopher Piggott Date: Fri, 29 Nov 2024 21:12:24 +0000 Subject: [PATCH 2/4] added original signature version of validateAssertion() for tests to be happy --- service_provider.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/service_provider.go b/service_provider.go index b6d87827..df584d79 100644 --- a/service_provider.go +++ b/service_provider.go @@ -1054,7 +1054,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunct return nil, err } - if err := sp.validateAssertion(&assertion, checkFunction, now); err != nil { + if err := sp.validateAssertion2(&assertion, checkFunction, now); err != nil { return nil, err } @@ -1065,7 +1065,11 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunct // the requirements to accept. If validation fails, it returns an error describing // the failure. (The digital signature on the assertion is not checked -- this // should be done before calling this function). -func (sp *ServiceProvider) validateAssertion(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error { +func (sp *ServiceProvider) validateAssertion(assertion *Assertion, allowedRequestIds []string, now time.Time) error { + return sp.validateAssertion2(assertion, createDefaultChecker(allowedRequestIds), now) +} + +func (sp *ServiceProvider) validateAssertion2(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error { if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) { return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay)) } From 41d38e8c440aad11f7032f6d109a9ffbea63ded4 Mon Sep 17 00:00:00 2001 From: Christopher Piggott Date: Tue, 3 Dec 2024 18:59:09 +0000 Subject: [PATCH 3/4] change name of project temporarily --- go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 39f9c8d2..c6dc8472 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ -module github.com/crewjam/saml +module github.com/wz2b/saml -go 1.19 +go 1.23 require ( github.com/beevik/etree v1.2.0 From 5c1ab6f4685688dcfbbfcb3e2afdd58627aa28e4 Mon Sep 17 00:00:00 2001 From: Christopher Piggott Date: Tue, 3 Dec 2024 19:15:59 +0000 Subject: [PATCH 4/4] change name back to original --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c6dc8472..bb6d819f 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/wz2b/saml +module github.com/crewjam/saml go 1.23