From 4a651a6ab7fc209c0dc706a6c81d7df92e8e9467 Mon Sep 17 00:00:00 2001 From: Ross Kinder Date: Sun, 7 Jan 2018 21:11:02 -0500 Subject: [PATCH] samlsp: move the setting and reading of cookies into an interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We’ve had a bunch of changes requesting the ability to customize how cookies are set and it is getting a little messy. This change moves the code to setting and reading cookies into two interfaces which you can extend/customize. --- samlsp/cookie.go | 105 ++++++++++++++++++++++++++++++++++++++ samlsp/middleware.go | 67 +++++++----------------- samlsp/middleware_test.go | 12 +++-- samlsp/samlsp.go | 21 +++++--- 4 files changed, 146 insertions(+), 59 deletions(-) create mode 100644 samlsp/cookie.go diff --git a/samlsp/cookie.go b/samlsp/cookie.go new file mode 100644 index 00000000..f05c0b20 --- /dev/null +++ b/samlsp/cookie.go @@ -0,0 +1,105 @@ +package samlsp + +import ( + "net/http" + "strings" + "time" + + "github.com/crewjam/saml" +) + +// ClientState implements client side storage for state. +type ClientState interface { + SetState(w http.ResponseWriter, r *http.Request, id string, value string) + GetStates(r *http.Request) map[string]string + GetState(r *http.Request, id string) string + DeleteState(w http.ResponseWriter, r *http.Request, id string) error +} + +// ClientToken implements client side storage for signed authorization tokens. +type ClientToken interface { + GetToken(r *http.Request) string + SetToken(w http.ResponseWriter, r *http.Request, value string, maxAge time.Duration) +} + +const stateCookiePrefix = "saml_" +const defaultCookieName = "token" + +// ClientCookies implements ClientState and ClientToken using cookies. +type ClientCookies struct { + ServiceProvider *saml.ServiceProvider + Name string + Domain string + Secure bool +} + +// SetState stores the named state value by setting a cookie. +func (c ClientCookies) SetState(w http.ResponseWriter, r *http.Request, id string, value string) { + http.SetCookie(w, &http.Cookie{ + Name: stateCookiePrefix + id, + Value: value, + MaxAge: int(saml.MaxIssueDelay.Seconds()), + HttpOnly: true, + Secure: c.Secure || r.URL.Scheme == "https", + Path: c.ServiceProvider.AcsURL.Path, + }) +} + +// GetStates returns the currently stored states by reading cookies. +func (c ClientCookies) GetStates(r *http.Request) map[string]string { + rv := map[string]string{} + for _, cookie := range r.Cookies() { + if !strings.HasPrefix(cookie.Name, stateCookiePrefix) { + continue + } + name := strings.TrimPrefix(cookie.Name, stateCookiePrefix) + rv[name] = cookie.Value + } + return rv +} + +// GetState returns a single stored state by reading the cookies +func (c ClientCookies) GetState(r *http.Request, id string) string { + stateCookie, err := r.Cookie(stateCookiePrefix + id) + if err != nil { + return "" + } + return stateCookie.Value +} + +// DeleteState removes the named stored state by clearing the corresponding cookie. +func (c ClientCookies) DeleteState(w http.ResponseWriter, r *http.Request, id string) error { + cookie, err := r.Cookie(stateCookiePrefix + id) + if err != nil { + return err + } + cookie.Value = "" + cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} + http.SetCookie(w, cookie) + return nil +} + +// SetToken assigns the specified token by setting a cookie. +func (c ClientCookies) SetToken(w http.ResponseWriter, r *http.Request, value string, maxAge time.Duration) { + http.SetCookie(w, &http.Cookie{ + Name: c.Name, + Domain: c.Domain, + Value: value, + MaxAge: int(maxAge.Seconds()), + HttpOnly: true, + Secure: c.Secure || r.URL.Scheme == "https", + Path: "/", + }) +} + +// GetToken returns the token by reading the cookie. +func (c ClientCookies) GetToken(r *http.Request) string { + cookie, err := r.Cookie(c.Name) + if err != nil { + return "" + } + return cookie.Value +} + +var _ ClientState = ClientCookies{} +var _ ClientToken = ClientCookies{} diff --git a/samlsp/middleware.go b/samlsp/middleware.go index acb44c3c..430c6042 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -4,9 +4,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/xml" - "fmt" "net/http" - "strings" "time" "github.com/crewjam/saml" @@ -47,15 +45,11 @@ import ( type Middleware struct { ServiceProvider saml.ServiceProvider AllowIDPInitiated bool - CookieName string - CookieMaxAge time.Duration - CookieDomain string - CookieSecure bool + TokenMaxAge time.Duration + ClientState ClientState + ClientToken ClientToken } -const defaultCookieMaxAge = time.Hour -const defaultCookieName = "token" - var jwtSigningMethod = jwt.SigningMethodHS256 func randomBytes(n int) []byte { @@ -145,15 +139,7 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { return } - http.SetCookie(w, &http.Cookie{ - Name: fmt.Sprintf("saml_%s", relayState), - Value: signedState, - MaxAge: int(saml.MaxIssueDelay.Seconds()), - HttpOnly: true, - Secure: m.CookieSecure || r.URL.Scheme == "https", - Path: m.ServiceProvider.AcsURL.Path, - }) - + m.ClientState.SetState(w, r, relayState, signedState) if binding == saml.HTTPRedirectBinding { redirectURL := req.Redirect(relayState) w.Header().Add("Location", redirectURL.String()) @@ -178,16 +164,11 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string { rv := []string{} - for _, cookie := range r.Cookies() { - if !strings.HasPrefix(cookie.Name, "saml_") { - continue - } - m.ServiceProvider.Logger.Printf("getPossibleRequestIDs: cookie: %s", cookie.String()) - + for _, value := range m.ClientState.GetStates(r) { jwtParser := jwt.Parser{ ValidMethods: []string{jwtSigningMethod.Name}, } - token, err := jwtParser.Parse(cookie.Value, func(t *jwt.Token) (interface{}, error) { + token, err := jwtParser.Parse(value, func(t *jwt.Token) (interface{}, error) { secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) return secretBlock, nil }) @@ -214,10 +195,10 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) redirectURI := "/" - if r.Form.Get("RelayState") != "" { - stateCookie, err := r.Cookie(fmt.Sprintf("saml_%s", r.Form.Get("RelayState"))) - if err != nil { - m.ServiceProvider.Logger.Printf("cannot find corresponding cookie: %s", fmt.Sprintf("saml_%s", r.Form.Get("RelayState"))) + if relayState := r.Form.Get("RelayState"); relayState != "" { + stateValue := m.ClientState.GetState(r, relayState) + if stateValue == "" { + m.ServiceProvider.Logger.Printf("cannot find corresponding state: %s", relayState) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -225,11 +206,11 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion jwtParser := jwt.Parser{ ValidMethods: []string{jwtSigningMethod.Name}, } - state, err := jwtParser.Parse(stateCookie.Value, func(t *jwt.Token) (interface{}, error) { + state, err := jwtParser.Parse(stateValue, func(t *jwt.Token) (interface{}, error) { return secretBlock, nil }) if err != nil || !state.Valid { - m.ServiceProvider.Logger.Printf("Cannot decode state JWT: %s (%s)", err, stateCookie.Value) + m.ServiceProvider.Logger.Printf("Cannot decode state JWT: %s (%s)", err, stateValue) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -237,16 +218,14 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion redirectURI = claims["uri"].(string) // delete the cookie - stateCookie.Value = "" - stateCookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} - http.SetCookie(w, stateCookie) + m.ClientState.DeleteState(w, r, relayState) } now := saml.TimeNow() claims := AuthorizationToken{} claims.Audience = m.ServiceProvider.Metadata().EntityID claims.IssuedAt = now.Unix() - claims.ExpiresAt = now.Add(m.CookieMaxAge).Unix() + claims.ExpiresAt = now.Add(m.TokenMaxAge).Unix() claims.NotBefore = now.Unix() if sub := assertion.Subject; sub != nil { if nameID := sub.NameID; nameID != nil { @@ -265,23 +244,13 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion } } } - signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBlock) if err != nil { panic(err) } - http.SetCookie(w, &http.Cookie{ - Name: m.CookieName, - Domain: m.CookieDomain, - Value: signedToken, - MaxAge: int(m.CookieMaxAge.Seconds()), - HttpOnly: true, - Secure: m.CookieSecure || r.URL.Scheme == "https", - Path: "/", - }) - + m.ClientToken.SetToken(w, r, signedToken, m.TokenMaxAge) http.Redirect(w, r, redirectURI, http.StatusFound) } @@ -298,13 +267,13 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool { // SAML login flow. If the request is authorized, then the request context is // ammended with a Context object. func (m *Middleware) GetAuthorizationToken(r *http.Request) *AuthorizationToken { - cookie, err := r.Cookie(m.CookieName) - if err != nil { + tokenStr := m.ClientToken.GetToken(r) + if tokenStr == "" { return nil } tokenClaims := AuthorizationToken{} - token, err := jwt.ParseWithClaims(cookie.Value, &tokenClaims, func(t *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenStr, &tokenClaims, func(t *jwt.Token) (interface{}, error) { secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) return secretBlock, nil }) diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 4c4fbd07..7ce2a6df 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -75,9 +75,14 @@ func (test *MiddlewareTest) SetUpTest(c *C) { IDPMetadata: &saml.EntityDescriptor{}, Logger: logger.DefaultLogger, }, - CookieName: "ttt", - CookieMaxAge: time.Hour * 2, + TokenMaxAge: time.Hour * 2, } + cookieStore := ClientCookies{ + ServiceProvider: &test.Middleware.ServiceProvider, + Name: "ttt", + } + test.Middleware.ClientState = &cookieStore + test.Middleware.ClientToken = &cookieStore err := xml.Unmarshal([]byte(test.IDPMetadata), &test.Middleware.ServiceProvider.IDPMetadata) c.Assert(err, IsNil) } @@ -149,7 +154,8 @@ func (test *MiddlewareTest) TestRequireAccountNoCreds(c *C) { } func (test *MiddlewareTest) TestRequireAccountNoCredsSecure(c *C) { - test.Middleware.CookieSecure = true + cookieStore := test.Middleware.ClientState.(*ClientCookies) + cookieStore.Secure = true handler := test.Middleware.RequireAccount( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("not reached") diff --git a/samlsp/samlsp.go b/samlsp/samlsp.go index 4e54007b..1d271f7c 100644 --- a/samlsp/samlsp.go +++ b/samlsp/samlsp.go @@ -16,6 +16,8 @@ import ( "github.com/crewjam/saml/logger" ) +const defaultTokenMaxAge = time.Hour + // Options represents the parameters for creating a new middleware type Options struct { URL url.URL @@ -33,7 +35,6 @@ type Options struct { // New creates a new Middleware func New(opts Options) (*Middleware, error) { - metadataURL := opts.URL metadataURL.Path = metadataURL.Path + "/saml/metadata" acsURL := opts.URL @@ -43,9 +44,9 @@ func New(opts Options) (*Middleware, error) { logr = logger.DefaultLogger } - cookieMaxAge := opts.CookieMaxAge + tokenMaxAge := opts.CookieMaxAge if opts.CookieMaxAge == 0 { - cookieMaxAge = defaultCookieMaxAge + tokenMaxAge = defaultTokenMaxAge } m := &Middleware{ @@ -59,11 +60,17 @@ func New(opts Options) (*Middleware, error) { ForceAuthn: &opts.ForceAuthn, }, AllowIDPInitiated: opts.AllowIDPInitiated, - CookieName: defaultCookieName, - CookieMaxAge: cookieMaxAge, - CookieDomain: opts.URL.Host, - CookieSecure: opts.CookieSecure, + TokenMaxAge: tokenMaxAge, + } + + cookieStore := ClientCookies{ + ServiceProvider: &m.ServiceProvider, + Name: defaultCookieName, + Domain: opts.URL.Host, + Secure: opts.CookieSecure, } + m.ClientState = &cookieStore + m.ClientToken = &cookieStore // fetch the IDP metadata if needed. if opts.IDPMetadataURL == nil {