diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go index 6aa478a0..e8be7cb9 100644 --- a/example/trivial/trivial.go +++ b/example/trivial/trivial.go @@ -27,7 +27,7 @@ func main() { } rootURL, _ := url.Parse("http://localhost:8000") - idpMetadataURL, _ := url.Parse("https://www.testshib.org/metadata/testshib-providers.xml") + idpMetadataURL, _ := url.Parse("https://samltest.id/saml/idp") idpMetadata, err := samlsp.FetchMetadata( context.Background(), @@ -42,6 +42,7 @@ func main() { IDPMetadata: idpMetadata, Key: keyPair.PrivateKey.(*rsa.PrivateKey), Certificate: keyPair.Leaf, + SignRequest: true, }) if err != nil { panic(err) // TODO handle error diff --git a/samlsp/new.go b/samlsp/new.go index 451a65aa..a28a46b4 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -5,6 +5,7 @@ import ( "context" "crypto/rsa" "crypto/x509" + dsig "github.com/russellhaering/goxmldsig" "net/http" "net/url" "time" @@ -22,6 +23,7 @@ type Options struct { Intermediates []*x509.Certificate AllowIDPInitiated bool IDPMetadata *saml.EntityDescriptor + SignRequest bool ForceAuthn bool // TODO(ross): this should be *bool // The following fields exist <= 0.3.0, but are superceded by the new @@ -125,6 +127,10 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { if opts.ForceAuthn { forceAuthn = &opts.ForceAuthn } + signatureMethod := dsig.RSASHA1SignatureMethod + if !opts.SignRequest { + signatureMethod = "" + } return saml.ServiceProvider{ EntityID: opts.EntityID, @@ -136,6 +142,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { SloURL: *sloURL, IDPMetadata: opts.IDPMetadata, ForceAuthn: forceAuthn, + SignatureMethod: signatureMethod, AllowIDPInitiated: opts.AllowIDPInitiated, } } diff --git a/service_provider.go b/service_provider.go index 04d3eddf..a8b64a58 100644 --- a/service_provider.go +++ b/service_provider.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/flate" "crypto/rsa" + "crypto/tls" "crypto/x509" "encoding/base64" "encoding/xml" @@ -101,6 +102,9 @@ type ServiceProvider struct { // SignatureVerifier, if non-nil, allows you to implement an alternative way // to verify signatures. SignatureVerifier SignatureVerifier + + // SignatureMethod, if non-empty, authentication requests will be signed + SignatureMethod string } // MaxIssueDelay is the longest allowed time between when a SAML assertion is @@ -126,7 +130,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { validDuration = sp.MetadataValidDuration } - authnRequestsSigned := false + authnRequestsSigned := len(sp.SignatureMethod) > 0 wantAssertionsSigned := true validUntil := TimeNow().Add(validDuration) @@ -137,12 +141,6 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { certBytes = append(certBytes, intermediate.Raw...) } keyDescriptors = []KeyDescriptor{ - { - Use: "signing", - KeyInfo: KeyInfo{ - Certificate: base64.StdEncoding.EncodeToString(certBytes), - }, - }, { Use: "encryption", KeyInfo: KeyInfo{ @@ -156,6 +154,14 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { }, }, } + if len(sp.SignatureMethod) > 0 { + keyDescriptors = append(keyDescriptors, KeyDescriptor{ + Use: "signing", + KeyInfo: KeyInfo{ + Certificate: base64.StdEncoding.EncodeToString(certBytes), + }, + }) + } } return &EntityDescriptor{ @@ -330,9 +336,51 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque }, ForceAuthn: sp.ForceAuthn, } + if len(sp.SignatureMethod) > 0 { + if err := sp.SignAuthnRequest(&req); err != nil { + return nil, err + } + } return &req, nil } +// SignAuthnRequest adds the `Signature` element to the `AuthnRequest`. +func (sp *ServiceProvider) SignAuthnRequest(req *AuthnRequest) error { + keyPair := tls.Certificate{ + Certificate: [][]byte{sp.Certificate.Raw}, + PrivateKey: sp.Key, + Leaf: sp.Certificate, + } + // TODO: add intermediates for SP + //for _, cert := range sp.Intermediates { + // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) + //} + keyStore := dsig.TLSCertKeyStore(keyPair) + + if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && + sp.SignatureMethod != dsig.RSASHA256SignatureMethod && + sp.SignatureMethod != dsig.RSASHA512SignatureMethod { + return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) + } + signatureMethod := sp.SignatureMethod + signingContext := dsig.NewDefaultSigningContext(keyStore) + signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) + if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + return err + } + + assertionEl := req.Element() + + signedRequestEl, err := signingContext.SignEnveloped(assertionEl) + if err != nil { + return err + } + + sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1] + req.Signature = sigEl.(*etree.Element) + return nil +} + // MakePostAuthenticationRequest creates a SAML authentication request using // the HTTP-POST binding. It returns HTML text representing an HTML form that // can be sent presented to a browser to initiate the login process. diff --git a/service_provider_test.go b/service_provider_test.go index ff908ceb..ac831215 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -98,7 +98,7 @@ func TestSPCanSetAuthenticationNameIDFormat(t *testing.T) { assert.Equal(t, string(EmailAddressNameIDFormat), *req.NameIDPolicy.Format) } -func TestSPCanProduceMetadata(t *testing.T) { +func TestSPCanProduceMetadataWithEncryptionCert(t *testing.T) { test := NewServiceProviderTest() s := ServiceProvider{ Key: test.Key, @@ -116,13 +116,43 @@ func TestSPCanProduceMetadata(t *testing.T) { assert.Equal(t, ""+ "\n"+ " \n"+ - " \n"+ + " \n"+ " \n"+ " \n"+ " MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==\n"+ " \n"+ " \n"+ + " \n"+ + " \n"+ + " \n"+ + " \n"+ " \n"+ + " \n"+ + " \n"+ + " \n"+ + "", + string(spMetadata)) +} + +func TestSPCanProduceMetadataWithBothCerts(t *testing.T) { + test := NewServiceProviderTest() + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://example.com/saml2/metadata"), + AcsURL: mustParseURL("https://example.com/saml2/acs"), + SloURL: mustParseURL("https://example.com/saml2/slo"), + IDPMetadata: &EntityDescriptor{}, + SignatureMethod: "not-empty", + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + assert.NoError(t, err) + + spMetadata, err := xml.MarshalIndent(s.Metadata(), "", " ") + assert.NoError(t, err) + assert.Equal(t, ""+ + "\n"+ + " \n"+ " \n"+ " \n"+ " \n"+ @@ -134,6 +164,13 @@ func TestSPCanProduceMetadata(t *testing.T) { " \n"+ " \n"+ " \n"+ + " \n"+ + " \n"+ + " \n"+ + " MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==\n"+ + " \n"+ + " \n"+ + " \n"+ " \n"+ " \n"+ " \n"+ @@ -141,7 +178,7 @@ func TestSPCanProduceMetadata(t *testing.T) { string(spMetadata)) } -func TestCanProduceMetadataNoSigningKey(t *testing.T) { +func TestCanProduceMetadataNoCerts(t *testing.T) { test := NewServiceProviderTest() s := ServiceProvider{ MetadataURL: mustParseURL("https://example.com/saml2/metadata"), @@ -248,6 +285,62 @@ func TestSPCanProducePostRequest(t *testing.T) { string(form)) } +func TestSPCanProduceSignedRequest(t *testing.T) { + test := NewServiceProviderTest() + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + SignatureMethod: dsig.RSASHA1SignatureMethod, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + assert.NoError(t, err) + + redirectURL, err := s.MakeRedirectAuthenticationRequest("relayState") + assert.NoError(t, err) + + decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) + assert.NoError(t, err) + assert.Equal(t, + "idp.testshib.org", + redirectURL.Host) + assert.Equal(t, + "/idp/profile/SAML2/Redirect/SSO", + redirectURL.Path) + assert.Equal(t, + "https://15661444.ngrok.io/saml2/metadataXQ5+kdgOf34vpAemZRFalLlzjr0=Wtomi/PiWx0bMFlImy5soCrrDbdY4BR2Qb8woGqc8KsVtXAwvl6lfYE2tuoT0YS5ipPLMMsFG8dB1TmLcA+0lnUcqfBiTiiHEwTIo3193RIsoH3STlOmXqBQf9Ax2nRdX+/4HwIYF58lgUzOb+nur+zGL6mYw2xjQBw6YGaX9Cc=MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==", + string(decodedRequest)) +} + +func TestSPFailToProduceSignedRequestWithBogusSignatureMethod(t *testing.T) { + test := NewServiceProviderTest() + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + SignatureMethod: "bogus", + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + assert.NoError(t, err) + + _, err = s.MakeRedirectAuthenticationRequest("relayState") + assert.Errorf(t, err, "invalid signing method bogus") +} + func TestSPCanProducePostLogoutRequest(t *testing.T) { test := NewServiceProviderTest() TimeNow = func() time.Time {