diff --git a/go.mod b/go.mod index 39f9c8d2..bb6d819f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crewjam/saml -go 1.19 +go 1.23 require ( github.com/beevik/etree v1.2.0 diff --git a/requestid.go b/requestid.go new file mode 100644 index 00000000..bd3ab1b9 --- /dev/null +++ b/requestid.go @@ -0,0 +1,14 @@ +package saml + +type RequestIdCheckFunction func(string) bool + +func createDefaultChecker(possibleRequestIDs []string) RequestIdCheckFunction { + return func(id string) bool { + for _, possibleRequestID := range possibleRequestIDs { + if id == possibleRequestID { + return true + } + } + return false + } +} diff --git a/service_provider.go b/service_provider.go index 30b35670..df584d79 100644 --- a/service_provider.go +++ b/service_provider.go @@ -600,13 +600,17 @@ func (e ErrBadStatus) Error() string { // ParseResponse extracts the SAML IDP response received in req, resolves // artifacts when necessary, validates it, and returns the verified assertion. func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) { + return sp.ParseResponse2(req, createDefaultChecker(possibleRequestIDs)) +} + +func (sp *ServiceProvider) ParseResponse2(req *http.Request, checkId RequestIdCheckFunction) (*Assertion, error) { if artifactID := req.Form.Get("SAMLart"); artifactID != "" { - return sp.handleArtifactRequest(req.Context(), artifactID, possibleRequestIDs) + return sp.handleArtifactRequest(req.Context(), artifactID, checkId) } - return sp.parseResponseHTTP(req, possibleRequestIDs) + return sp.parseResponseHTTP(req, checkId) } -func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, checkRequestId RequestIdCheckFunction) (*Assertion, error) { retErr := &InvalidResponseError{Now: TimeNow()} artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID) @@ -652,14 +656,14 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err) return nil, retErr } - assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID) + assertion, err := sp.ParseXMLArtifactResponse2(responseBody, checkRequestId, artifactResolveRequest.ID) if err != nil { return nil, err } return assertion, nil } -func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, idCheck RequestIdCheckFunction) (*Assertion, error) { retErr := &InvalidResponseError{ Now: TimeNow(), } @@ -670,16 +674,23 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI return nil, retErr } - assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs) + assertion, err := sp.ParseXMLResponse2(rawResponseBuf, idCheck) if err != nil { return nil, err } return assertion, nil } +func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { + return sp.ParseXMLArtifactResponse2(soapResponseXML, + createDefaultChecker(possibleRequestIDs), + artifactRequestID) +} + // ParseXMLArtifactResponse validates the SAML Artifact resolver response // and returns the verified assertion. // + // This function handles verifying the digital signature, and verifying // that the specified conditions and properties are met. // @@ -687,7 +698,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. -func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLArtifactResponse2(soapResponseXML []byte, checkFunction RequestIdCheckFunction, artifactRequestID string) (*Assertion, error) { now := TimeNow() retErr := &InvalidResponseError{ Response: string(soapResponseXML), @@ -727,10 +738,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss return nil, retErr } - return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now) + return sp.parseArtifactResponse(artifactResponseEl, checkFunction, artifactRequestID, now) } -func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) { +func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, checkFunction RequestIdCheckFunction, artifactRequestID string, now time.Time) (*Assertion, error) { retErr := &InvalidResponseError{ Now: now, Response: elementToString(artifactResponseEl), @@ -778,7 +789,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme return nil, retErr } - assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseResponse(responseEl, checkFunction, now, signatureRequirement) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -799,6 +810,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) { + now := TimeNow() var err error retErr := &InvalidResponseError{ @@ -822,7 +834,40 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR return nil, retErr } - assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired) + assertion, err := sp.parseResponse(doc.Root(), createDefaultChecker(possibleRequestIDs), now, signatureRequired) + if err != nil { + retErr.PrivateErr = err + return nil, retErr + } + + return assertion, nil + +} +func (sp *ServiceProvider) ParseXMLResponse2(decodedResponseXML []byte, checkFunction RequestIdCheckFunction) (*Assertion, error) { + now := TimeNow() + var err error + retErr := &InvalidResponseError{ + Now: now, + Response: string(decodedResponseXML), + } + + // ensure that the response XML is well-formed before we parse it + if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil { + retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err) + return nil, retErr + } + + doc := etree.NewDocument() + if err := doc.ReadFromBytes(decodedResponseXML); err != nil { + retErr.PrivateErr = err + return nil, retErr + } + if doc.Root() == nil { + retErr.PrivateErr = errors.New("invalid xml: no root") + return nil, retErr + } + + assertion, err := sp.parseResponse(doc.Root(), checkFunction, now, signatureRequired) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -844,7 +889,7 @@ const ( // This function handles decrypting the message, verifying the digital // signature on the assertion, and verifying that the specified conditions // and properties are met. -func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { var responseSignatureErr error var responseHasSignature bool if signatureRequirement == signatureRequired { @@ -876,14 +921,10 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ if sp.AllowIDPInitiated { requestIDvalid = true } else { - for _, possibleRequestID := range possibleRequestIDs { - if response.InResponseTo == possibleRequestID { - requestIDvalid = true - } - } + requestIDvalid = checkFunction(response.InResponseTo) } if !requestIDvalid { - return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs) + return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs") } if response.IssueInstant.Add(MaxIssueDelay).Before(now) { @@ -920,7 +961,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return nil, err } for _, encryptedAssertionEl := range encryptedAssertionEls { - assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, checkFunction, now, signatureRequirement) if err != nil { errs = append(errs, err) continue @@ -936,7 +977,7 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return nil, err } for _, assertionEl := range assertionEls { - assertion, err := sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement) if err != nil { errs = append(errs, err) continue @@ -959,12 +1000,12 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return &assertions[0], nil } -func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { assertionEl, err := sp.decryptElement(encryptedAssertionEl) if err != nil { return nil, fmt.Errorf("failed to decrypt EncryptedAssertion: %v", err) } - return sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement) + return sp.parseAssertion(assertionEl, checkFunction, now, signatureRequirement) } func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.Element, error) { @@ -999,7 +1040,7 @@ func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.El return doc.Root(), nil } -func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, checkFunction RequestIdCheckFunction, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { if signatureRequirement == signatureRequired { sigErr := sp.validateSignature(assertionEl) if sigErr != nil { @@ -1013,7 +1054,7 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe return nil, err } - if err := sp.validateAssertion(&assertion, possibleRequestIDs, now); err != nil { + if err := sp.validateAssertion2(&assertion, checkFunction, now); err != nil { return nil, err } @@ -1024,7 +1065,11 @@ func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRe // the requirements to accept. If validation fails, it returns an error describing // the failure. (The digital signature on the assertion is not checked -- this // should be done before calling this function). -func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error { +func (sp *ServiceProvider) validateAssertion(assertion *Assertion, allowedRequestIds []string, now time.Time) error { + return sp.validateAssertion2(assertion, createDefaultChecker(allowedRequestIds), now) +} + +func (sp *ServiceProvider) validateAssertion2(assertion *Assertion, checkFunction RequestIdCheckFunction, now time.Time) error { if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) { return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay)) } @@ -1052,14 +1097,9 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque // Finally, it is unclear that there is significant security value in checking InResponseTo when we allow // IDP initiated assertions. if !sp.AllowIDPInitiated { - for _, possibleRequestID := range possibleRequestIDs { - if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID { - requestIDvalid = true - break - } - } + requestIDvalid = checkFunction(subjectConfirmation.SubjectConfirmationData.InResponseTo) if !requestIDvalid { - return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs) + return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs") } } if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() { diff --git a/xmlenc/pubkey.go b/xmlenc/pubkey.go index 13d4d9e7..e1daee93 100644 --- a/xmlenc/pubkey.go +++ b/xmlenc/pubkey.go @@ -125,6 +125,8 @@ func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, erro // the block cipher used is AES-256 CBC and the digest method is SHA-256. You can // specify other ciphers and digest methods by assigning to BlockCipher or // DigestMethod. +// +// OAEP implements the older RSA-OAEP (2001 spec) for backward compatibility func OAEP() RSA { return RSA{ BlockCipher: AES256CBC, @@ -139,6 +141,36 @@ func OAEP() RSA { } } +func OAEP_2009_256() RSA { + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA256, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + +func OAEP_2009_512() RSA { + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA512, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + // PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default // the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15 // does not use a digest function. @@ -158,5 +190,7 @@ func PKCS1v15() RSA { func init() { RegisterDecrypter(OAEP()) + RegisterDecrypter(OAEP_2009_256()) + RegisterDecrypter(OAEP_2009_512()) RegisterDecrypter(PKCS1v15()) }