Skip to content

Commit

Permalink
Merge pull request #234 from stripe/saurabhbhatia/smokescreen-ctx-cha…
Browse files Browse the repository at this point in the history
…nges

Make SmokeScreen Context Fields Public
  • Loading branch information
saurabhbhatia-stripe authored Oct 11, 2024
2 parents f6f8191 + c75cffb commit 688e70b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 46 deletions.
91 changes: 46 additions & 45 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const (
type ipType int

type ACLDecision struct {
reason, role, project, outboundHost string
Reason, Role, Project, OutboundHost string
ResolvedAddr *net.TCPAddr
allow bool
enforceWouldDeny bool
Expand All @@ -79,9 +79,9 @@ type SmokescreenContext struct {
cfg *Config
start time.Time
Decision *ACLDecision
proxyType string
logger *logrus.Entry
requestedHost string
ProxyType string
Logger *logrus.Entry
RequestedHost string

// Time spent resolving the requested hostname
lookupTime time.Duration
Expand Down Expand Up @@ -257,11 +257,11 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
}
d := sctx.Decision

// If an address hasn't been resolved, does not match the original outboundHost,
// If an address hasn't been resolved, does not match the original OutboundHost,
// or is not tcp we must re-resolve it before establishing the connection.
if d.ResolvedAddr == nil || d.outboundHost != addr || network != "tcp" {
if d.ResolvedAddr == nil || d.OutboundHost != addr || network != "tcp" {
var err error
d.ResolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr)
d.ResolvedAddr, d.Reason, err = safeResolve(sctx.cfg, network, addr)
if err != nil {
if _, ok := err.(denyError); ok {
sctx.cfg.Log.WithFields(
Expand Down Expand Up @@ -289,25 +289,25 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout)
}
connTime := time.Since(start)
sctx.logger = sctx.logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime))
sctx.Logger = sctx.Logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime))

if sctx.cfg.TimeConnect {
sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.requestedHost}, 1)
sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.RequestedHost}, 1)
}

if err != nil {
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "false"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, false)
sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, false)
metrics.ReportConnError(sctx.cfg.MetricsClient, err)
return nil, err
}
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, true)
sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, true)

// Only wrap CONNECT conns with an InstrumentedConn. Connections used for traditional HTTP proxy
// requests are pooled and reused by net.Transport.
if sctx.proxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType)
if sctx.ProxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.Logger, d.Role, d.OutboundHost, sctx.ProxyType)
pctx.ConnErrorHandler = ic.Error
conn = ic
} else {
Expand Down Expand Up @@ -346,11 +346,11 @@ func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) {
resp := rejectResponse(pctx, err)

if err := resp.Write(w); err != nil {
sctx.logger.Errorf("Failed to write HTTP error response: %s", err)
sctx.Logger.Errorf("Failed to write HTTP error response: %s", err)
}

if err := w.Close(); err != nil {
sctx.logger.Errorf("Failed to close proxy client connection: %s", err)
sctx.Logger.Errorf("Failed to close proxy client connection: %s", err)
}
}

Expand Down Expand Up @@ -384,12 +384,12 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response {
status = "Internal server error"
code = http.StatusInternalServerError
msg = "An unexpected error occurred: " + err.Error()
sctx.logger.WithField("error", err.Error()).Warn("rejectResponse called with unexpected error")
sctx.Logger.WithField("error", err.Error()).Warn("rejectResponse called with unexpected error")
}

// Do not double log deny errors, they are logged in a previous call to logProxy.
if _, ok := err.(denyError); !ok {
sctx.logger.Error(msg)
sctx.Logger.Error(msg)
}

if sctx.cfg.AdditionalErrorMessageOnDeny != "" {
Expand Down Expand Up @@ -438,10 +438,10 @@ func newContext(cfg *Config, proxyType string, req *http.Request) *SmokescreenCo

return &SmokescreenContext{
cfg: cfg,
logger: logger,
proxyType: proxyType,
Logger: logger,
ProxyType: proxyType,
start: start,
requestedHost: req.Host,
RequestedHost: req.Host,
}
}

Expand Down Expand Up @@ -493,7 +493,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
req.Header.Del(traceHeader)
}()

sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")
sctx.Logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")
// Build an address parsable by net.ResolveTCPAddr
destination, err := hostport.NewWithScheme(req.Host, req.URL.Scheme, false)
if err != nil {
Expand All @@ -510,7 +510,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
return req, rejectResponse(pctx, pctx.Error)
}
if !sctx.Decision.allow {
return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.reason)})
return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.Reason)})
}

// Call the custom request handler if it exists
Expand Down Expand Up @@ -576,7 +576,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// We don't want to log if the connection is a MITM as it will be done in HandleConnectFunc
if pctx.ConnectAction != goproxy.ConnectMitm {
// In case of an error, this function is called a second time to filter the
// response we generate so this logger will be called once.
// response we generate so this Logger will be called once.
logProxy(pctx)
}
return resp
Expand All @@ -586,6 +586,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// The goproxy OnResponse() function above is only called for non-https responses.
if config.AcceptResponseHandler != nil {
proxy.ConnectRespHandler = func(pctx *goproxy.ProxyCtx, resp *http.Response) error {

sctx, ok := pctx.UserData.(*SmokescreenContext)
if !ok {
return fmt.Errorf("goproxy ProxyContext missing required UserData *SmokescreenContext")
Expand All @@ -611,7 +612,7 @@ func logProxy(pctx *goproxy.ProxyCtx) {
}

if sctx.Decision != nil {
fields[LogFieldDecisionReason] = decision.reason
fields[LogFieldDecisionReason] = decision.Reason
fields[LogFieldEnforceWouldDeny] = decision.enforceWouldDeny
fields[LogFieldAllow] = decision.allow
}
Expand All @@ -621,7 +622,7 @@ func logProxy(pctx *goproxy.ProxyCtx) {
fields[LogFieldError] = err.Error()
}

entry := sctx.logger.WithFields(fields)
entry := sctx.Logger.WithFields(fields)
var logMethod func(...interface{})
if _, ok := err.(denyError); !ok && err != nil {
logMethod = entry.Error
Expand All @@ -648,16 +649,16 @@ func extractContextLogFields(pctx *goproxy.ProxyCtx, sctx *SmokescreenContext) l
// Retrieve information from the ACL decision
decision := sctx.Decision
if sctx.Decision != nil {
fields[LogFieldRole] = decision.role
fields[LogFieldProject] = decision.project
fields[LogFieldRole] = decision.Role
fields[LogFieldProject] = decision.Project
}
return fields
}

func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string, error) {
sctx := pctx.UserData.(*SmokescreenContext)

// Check if requesting role is allowed to talk to remote
// Check if requesting Role is allowed to talk to remote
destination, err := hostport.New(pctx.Req.Host, false)
if err != nil {
pctx.Error = denyError{err}
Expand All @@ -673,11 +674,11 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectActi
return nil, "", pctx.Error
}

// add context fields to all future log messages sent using this smokescreen context's logger
sctx.logger = sctx.logger.WithFields(extractContextLogFields(pctx, sctx))
// add context fields to all future log messages sent using this smokescreen context's Logger
sctx.Logger = sctx.Logger.WithFields(extractContextLogFields(pctx, sctx))

if !sctx.Decision.allow {
return nil, "", denyError{errors.New(sctx.Decision.reason)}
return nil, "", denyError{errors.New(sctx.Decision.Reason)}
}

// Call the custom request handler if it exists
Expand All @@ -696,7 +697,7 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectActi
deny := denyError{errors.New("ACLDecision specified MITM but Smokescreen doesn't have MITM enabled")}
sctx.Decision.allow = false
sctx.Decision.MitmConfig = nil
sctx.Decision.reason = deny.Error()
sctx.Decision.Reason = deny.Error()
return nil, "", deny
}
mitm := sctx.Decision.MitmConfig
Expand Down Expand Up @@ -912,10 +913,10 @@ func runServer(config *Config, server *http.Server, listener net.Listener, quit
}
}

// Extract the client's ACL role from the HTTP request, using the configured
// RoleFromRequest function. Returns the role, or an error if the role cannot
// Extract the client's ACL Role from the HTTP request, using the configured
// RoleFromRequest function. Returns the Role, or an error if the Role cannot
// be determined (including no RoleFromRequest configured), unless
// AllowMissingRole is configured, in which case an empty role and no error is
// AllowMissingRole is configured, in which case an empty Role and no error is
// returned.
func getRole(config *Config, req *http.Request) (string, error) {
var role string
Expand Down Expand Up @@ -955,7 +956,7 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio
if _, ok := err.(denyError); !ok {
return decision, lookupTime, err
}
decision.reason = fmt.Sprintf("%s. %s", err.Error(), reason)
decision.Reason = fmt.Sprintf("%s. %s", err.Error(), reason)
decision.allow = false
decision.enforceWouldDeny = true
} else {
Expand All @@ -968,28 +969,28 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio

func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *ACLDecision {
decision := &ACLDecision{
outboundHost: destination.String(),
OutboundHost: destination.String(),
}

if config.EgressACL == nil {
decision.allow = true
decision.reason = "Egress ACL is not configured"
decision.Reason = "Egress ACL is not configured"
return decision
}

role, roleErr := getRole(config, req)
if roleErr != nil {
config.MetricsClient.Incr("acl.role_not_determined", 1)
decision.reason = "Client role cannot be determined"
decision.Reason = "Client role cannot be determined"
return decision
}

decision.role = role
decision.Role = role

// This host validation prevents IPv6 addresses from being used as destinations.
// Added for backwards compatibility.
if strings.ContainsAny(destination.Host, ":") {
decision.reason = "Destination host cannot be determined"
decision.Reason = "Destination host cannot be determined"
return decision
}

Expand Down Expand Up @@ -1024,8 +1025,8 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
}

ACLDecision, err := config.EgressACL.Decide(role, destination.Host, connectProxyHost)
decision.project = ACLDecision.Project
decision.reason = ACLDecision.Reason
decision.Project = ACLDecision.Project
decision.Reason = ACLDecision.Reason
decision.MitmConfig = ACLDecision.MitmConfig
if err != nil {
config.Log.WithFields(logrus.Fields{
Expand All @@ -1038,7 +1039,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
}

tags := map[string]string{
"role": decision.role,
"role": decision.Role,
"def_rule": fmt.Sprintf("%t", ACLDecision.Default),
"project": ACLDecision.Project,
}
Expand All @@ -1064,7 +1065,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
"destination": destination.Host,
"action": ACLDecision.Result.String(),
}).Warn("Unknown ACL action")
decision.reason = "Internal error"
decision.Reason = "Internal error"
config.MetricsClient.IncrWithTags("acl.unknown_error", tags, 1)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func TestProxyTimeouts(t *testing.T) {
// for an EOF returned from HTTP client to indicate a connection interruption
// which in our case represents the timeout.
//
// To correctly hook into this, we'd need to pass a logger from Smokescreen to Goproxy
// To correctly hook into this, we'd need to pass a Logger from Smokescreen to Goproxy
// which we have hooks into. This would be able to verify the timeout as errors from
// each end of the connection pair are logged by Goproxy.
t.Run("CONNECT proxy timeouts", func(t *testing.T) {
Expand Down

0 comments on commit 688e70b

Please sign in to comment.