Skip to content

Commit

Permalink
ModifyResponse feature to be able to use secure with http.ReverseProxy (
Browse files Browse the repository at this point in the history
#31)

* Add .idea directory to .gitignore

* fix: go fmt

* feat: Add ModifyResponseHeaders function to be able to use middleware with http.ReverseProxy

* Restore gitignore

* fix review + add tests

* fix review comments

* fix: use request context to save response header to avoid race
  • Loading branch information
mmatur authored and unrolled committed Feb 26, 2018
1 parent 5b5ec9d commit 7fac758
Show file tree
Hide file tree
Showing 4 changed files with 429 additions and 28 deletions.
4 changes: 2 additions & 2 deletions csp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ func CSPNonce(c context.Context) string {
// WithCSPNonce returns a context derived from ctx containing the given nonce as a value.
//
// This is intended for testing or more advanced use-cases;
// for ordinary HTTP handlers, clients can rely on this package's middleware to populate the CSP nonce in the context.
// for ordinary HTTP handlers, clients can rely on this package's middleware to populate the CSP nonce in the context.
func WithCSPNonce(ctx context.Context, nonce string) context.Context {
return context.WithValue(ctx, cspNonceKey, nonce)
return context.WithValue(ctx, cspNonceKey, nonce)
}

func withCSPNonce(r *http.Request, nonce string) *http.Request {
Expand Down
132 changes: 106 additions & 26 deletions secure.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secure

import (
"context"
"fmt"
"net/http"
"strings"
Expand All @@ -20,7 +21,8 @@ const (
hpkpHeader = "Public-Key-Pins"
referrerPolicyHeader = "Referrer-Policy"

cspNonceSize = 16
ctxSecureHeaderKey = "SecureResponseHeader"
cspNonceSize = 16
)

func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -128,6 +130,29 @@ func (s *Secure) Handler(h http.Handler) http.Handler {
})
}

// HandlerForRequestOnly implements the http.HandlerFunc for integration with the standard net/http lib.
func (s *Secure) HandlerForRequestOnly(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if s.opt.nonceEnabled {
r = withCSPNonce(r, cspRandNonce())
}

// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
responseHeader, err := s.processRequest(w, r)

// If there was an error, do not continue.
if err != nil {
return
}

// Save response headers in the request context
ctx := context.WithValue(r.Context(), ctxSecureHeaderKey, responseHeader)
// No headers will be written to the ResponseWriter.
h.ServeHTTP(w, r.WithContext(ctx))
})
}

// HandlerFuncWithNext is a special implementation for Negroni, but could be used elsewhere.
func (s *Secure) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if s.opt.nonceEnabled {
Expand All @@ -144,8 +169,40 @@ func (s *Secure) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, nex
}
}

// Process runs the actual checks and returns an error if the middleware chain should stop.
// HandlerFuncWithNextForRequestOnly is a special implementation for Negroni, but could be used elsewhere.
func (s *Secure) HandlerFuncWithNextForRequestOnly(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if s.opt.nonceEnabled {
r = withCSPNonce(r, cspRandNonce())
}

// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
responseHeader, err := s.processRequest(w, r)

// If there was an error, do not call next.
if err == nil && next != nil {
// Save response headers in the request context
ctx := context.WithValue(r.Context(), ctxSecureHeaderKey, responseHeader)
// No headers will be written to the ResponseWriter.
next(w, r.WithContext(ctx))
}
}

// Process runs the actual checks and writes the headers in the ResponseWriter.
func (s *Secure) Process(w http.ResponseWriter, r *http.Request) error {
responseHeader, err := s.processRequest(w, r)
if responseHeader != nil {
for key, values := range responseHeader {
for _, value := range values {
w.Header().Add(key, value)
}
}
}
return err
}

// processRequest runs the actual checks on the request and returns an error if the middleware chain should stop.
func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.Header, error) {
// Resolve the host for the request, using proxy headers if present.
host := r.Host
for _, header := range s.opt.HostsProxyHeaders {
Expand All @@ -167,23 +224,15 @@ func (s *Secure) Process(w http.ResponseWriter, r *http.Request) error {

if !isGoodHost {
s.badHostHandler.ServeHTTP(w, r)
return fmt.Errorf("Bad host name: %s", host)
return nil, fmt.Errorf("bad host name: %s", host)
}
}

// Determine if we are on HTTPS.
isSSL := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil
if !isSSL {
for k, v := range s.opt.SSLProxyHeaders {
if r.Header.Get(k) == v {
isSSL = true
break
}
}
}
ssl := s.isSSL(r)

// SSL check.
if s.opt.SSLRedirect && !isSSL && !s.opt.IsDevelopment {
if s.opt.SSLRedirect && !ssl && !s.opt.IsDevelopment {
url := r.URL
url.Scheme = "https"
url.Host = host
Expand All @@ -198,12 +247,13 @@ func (s *Secure) Process(w http.ResponseWriter, r *http.Request) error {
}

http.Redirect(w, r, url.String(), status)
return fmt.Errorf("Redirecting to HTTPS")
return nil, fmt.Errorf("redirecting to HTTPS")
}

responseHeader := make(http.Header)
// Strict Transport Security header. Only add header when we know it's an SSL connection.
// See https://tools.ietf.org/html/rfc6797#section-7.2 for details.
if s.opt.STSSeconds != 0 && (isSSL || s.opt.ForceSTSHeader) && !s.opt.IsDevelopment {
if s.opt.STSSeconds != 0 && (ssl || s.opt.ForceSTSHeader) && !s.opt.IsDevelopment {
stsSub := ""
if s.opt.STSIncludeSubdomains {
stsSub = stsSubdomainString
Expand All @@ -213,46 +263,76 @@ func (s *Secure) Process(w http.ResponseWriter, r *http.Request) error {
stsSub += stsPreloadString
}

w.Header().Add(stsHeader, fmt.Sprintf("max-age=%d%s", s.opt.STSSeconds, stsSub))
responseHeader.Set(stsHeader, fmt.Sprintf("max-age=%d%s", s.opt.STSSeconds, stsSub))
}

// Frame Options header.
if len(s.opt.CustomFrameOptionsValue) > 0 {
w.Header().Add(frameOptionsHeader, s.opt.CustomFrameOptionsValue)
responseHeader.Set(frameOptionsHeader, s.opt.CustomFrameOptionsValue)
} else if s.opt.FrameDeny {
w.Header().Add(frameOptionsHeader, frameOptionsValue)
responseHeader.Set(frameOptionsHeader, frameOptionsValue)
}

// Content Type Options header.
if s.opt.ContentTypeNosniff {
w.Header().Add(contentTypeHeader, contentTypeValue)
responseHeader.Set(contentTypeHeader, contentTypeValue)
}

// XSS Protection header.
if len(s.opt.CustomBrowserXssValue) > 0 {
w.Header().Add(xssProtectionHeader, s.opt.CustomBrowserXssValue)
responseHeader.Set(xssProtectionHeader, s.opt.CustomBrowserXssValue)
} else if s.opt.BrowserXssFilter {
w.Header().Add(xssProtectionHeader, xssProtectionValue)
responseHeader.Set(xssProtectionHeader, xssProtectionValue)
}

// HPKP header.
if len(s.opt.PublicKey) > 0 && isSSL && !s.opt.IsDevelopment {
w.Header().Add(hpkpHeader, s.opt.PublicKey)
if len(s.opt.PublicKey) > 0 && ssl && !s.opt.IsDevelopment {
responseHeader.Set(hpkpHeader, s.opt.PublicKey)
}

// Content Security Policy header.
if len(s.opt.ContentSecurityPolicy) > 0 {
if s.opt.nonceEnabled {
w.Header().Add(cspHeader, fmt.Sprintf(s.opt.ContentSecurityPolicy, CSPNonce(r.Context())))
responseHeader.Set(cspHeader, fmt.Sprintf(s.opt.ContentSecurityPolicy, CSPNonce(r.Context())))
} else {
w.Header().Add(cspHeader, s.opt.ContentSecurityPolicy)
responseHeader.Set(cspHeader, s.opt.ContentSecurityPolicy)
}
}

// Referrer Policy header.
if len(s.opt.ReferrerPolicy) > 0 {
w.Header().Add(referrerPolicyHeader, s.opt.ReferrerPolicy)
responseHeader.Set(referrerPolicyHeader, s.opt.ReferrerPolicy)
}

return responseHeader, nil
}

// isSSL determine if we are on HTTPS.
func (s *Secure) isSSL(r *http.Request) bool {
ssl := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil
if !ssl {
for k, v := range s.opt.SSLProxyHeaders {
if r.Header.Get(k) == v {
ssl = true
break
}
}
}
return ssl
}

// ModifyResponseHeaders modifies the Response.
// Used by http.ReverseProxy.
func (s *Secure) ModifyResponseHeaders(res *http.Response) error {
if res != nil && res.Request != nil {
responseHeader := res.Request.Context().Value(ctxSecureHeaderKey)
if responseHeader != nil {
for header, values := range responseHeader.(http.Header) {
if len(values) > 0 {
res.Header.Set(header, strings.Join(values, ","))
}
}
}
}
return nil
}
25 changes: 25 additions & 0 deletions secure_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,28 @@ func TestIntegrationWithError(t *testing.T) {

expect(t, res.Code, http.StatusInternalServerError)
}

func TestIntegrationForRequestOnly(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "bar")
})

secureMiddleware := New(Options{
ContentTypeNosniff: true,
FrameDeny: true,
})

n := negroni.New()
n.Use(negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly))
n.UseHandler(mux)

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
n.ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), "bar")
expect(t, res.Header().Get("X-Frame-Options"), "")
expect(t, res.Header().Get("X-Content-Type-Options"), "")
}
Loading

0 comments on commit 7fac758

Please sign in to comment.