Skip to content

Commit

Permalink
Add Support for Reject Handler with Context
Browse files Browse the repository at this point in the history
  • Loading branch information
saurabhbhatia-stripe committed Oct 3, 2024
1 parent dab4bde commit 3713647
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ type Config struct {
// Custom handler to allow clients to modify reject responses
RejectResponseHandler func(*http.Response)

// Custom handler to allow clients to modify reject responses
RejectResponseHandlerWithCtx func(*SmokescreenContext, *http.Response)

// Custom handler to allow clients to modify successful CONNECT responses
AcceptResponseHandler func(*SmokescreenContext, *http.Response) error

Expand Down
3 changes: 3 additions & 0 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response {
if sctx.cfg.RejectResponseHandler != nil {
sctx.cfg.RejectResponseHandler(resp)
}
if sctx.cfg.RejectResponseHandlerWithCtx != nil {
sctx.cfg.RejectResponseHandlerWithCtx(sctx, resp)
}
return resp
}

Expand Down
41 changes: 41 additions & 0 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,47 @@ func TestRejectResponseHandler(t *testing.T) {
})
}

func TestRejectResponseHandlerWithCtx(t *testing.T) {
r := require.New(t)
testHeader := "TestRejectResponseHandlerWithCtxHeader"
t.Run("Testing custom reject response handler", func(t *testing.T) {
cfg, err := testConfig("test-local-srv")

// set a custom RejectResponseHandler that will set a header on every reject response
cfg.RejectResponseHandlerWithCtx = func(_ *SmokescreenContext, resp *http.Response) {
resp.Header.Set(testHeader, "This header is added by the RejectResponseHandlerWithCtx")
}
r.NoError(err)

proxySrv := proxyServer(cfg)
r.NoError(err)
defer proxySrv.Close()

// Create a http.Client that uses our proxy
client, err := proxyClient(proxySrv.URL)
r.NoError(err)

// Send a request that should be blocked
resp, err := client.Get("http://127.0.0.1")
r.NoError(err)

// The RejectResponseHandler should set our custom header
h := resp.Header.Get(testHeader)
if h == "" {
t.Errorf("Expecting header %s to be set by RejectResponseHandler", testHeader)
}
// Send a request that should be allowed
resp, err = client.Get("http://example.com")
r.NoError(err)

// The header set by our custom reject response handler should not be set
h = resp.Header.Get(testHeader)
if h != "" {
t.Errorf("Expecting header %s to not be set by RejectResponseHandler", testHeader)
}
})
}

// Test that Smokescreen calls the custom accept response handler (if defined in the Config struct)
// after every accepted request
func TestAcceptResponseHandler(t *testing.T) {
Expand Down

0 comments on commit 3713647

Please sign in to comment.