Skip to content

Commit

Permalink
chore: Remove redis init side-effect of app.NewHandler (#64062)
Browse files Browse the repository at this point in the history
This current code sets a global variable which couples the session
package and the app package and the session package cannot work when the
NewHandler method wasn't called yet.

This also used to wait for redis implicitly, without any indication of
doing that.

To be more explicit about that, we now do that in the main function of
frontend instead, and create the session store where required on the
fly.

Test plan: Login still works, integration/E2E tests are passing.
  • Loading branch information
eseliger committed Jul 31, 2024
1 parent 03c2907 commit 7a7c663
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 122 deletions.
5 changes: 0 additions & 5 deletions cmd/frontend/internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/app/router"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/app/ui"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/auth/accessrequest"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/auth/session"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/auth/userpasswd"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
Expand All @@ -22,10 +21,6 @@ import (
// 🚨 SECURITY: The caller MUST wrap the returned handler in middleware that checks authentication
// and sets the actor in the request context.
func NewHandler(db database.DB, logger log.Logger) http.Handler {
session.SetSessionStore(session.NewRedisStore(func() bool {
return conf.ExternalURLParsed().Scheme == "https"
}))

logger = logger.Scoped("appHandler")

r := router.Router()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import (
// various endpoints, but does NOT cover the logic that is contained within `golang.org/x/oauth2`
// and `github.com/dghubble/gologin` which ensures the correctness of the `/callback` handler.
func TestMiddleware(t *testing.T) {
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

db := dbmocks.NewMockDB()

Expand Down
3 changes: 1 addition & 2 deletions cmd/frontend/internal/auth/githuboauth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import (
// and `github.com/dghubble/gologin` which ensures the correctness of the `/callback` handler.
func TestMiddleware(t *testing.T) {
logger := logtest.Scoped(t)
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

db := database.NewDB(logger, dbtest.NewDB(t))

Expand Down
3 changes: 1 addition & 2 deletions cmd/frontend/internal/auth/gitlaboauth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ import (
// and `github.com/dghubble/gologin` which ensures the correctness of the `/callback` handler.
func TestMiddleware(t *testing.T) {
logger := logtest.Scoped(t)
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

const mockUserID = 123

Expand Down
6 changes: 2 additions & 4 deletions cmd/frontend/internal/auth/openidconnect/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ func newOIDCIDServer(t *testing.T, code string, oidcProvider *schema.OpenIDConne

func TestMiddleware(t *testing.T) {
logger := logtest.Scoped(t)
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)
defer licensing.TestingSkipFeatureChecks()()

mockGetProviderValue = &Provider{
Expand Down Expand Up @@ -327,8 +326,7 @@ func TestMiddleware(t *testing.T) {

func TestMiddleware_NoOpenRedirect(t *testing.T) {
logger := logtest.Scoped(t)
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

defer licensing.TestingSkipFeatureChecks()()

Expand Down
3 changes: 1 addition & 2 deletions cmd/frontend/internal/auth/saml/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ func TestMiddleware(t *testing.T) {
providers.MockProviders = []providers.Provider{mockGetProviderValue}
defer func() { providers.MockProviders = nil }()

cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

providerID := providerConfigID(&mockGetProviderValue.config, true)

Expand Down
1 change: 0 additions & 1 deletion cmd/frontend/internal/auth/session/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ go_library(
"@com_github_boj_redistore//:redistore",
"@com_github_gorilla_securecookie//:securecookie",
"@com_github_gorilla_sessions//:sessions",
"@com_github_inconshreveable_log15//:log15",
"@com_github_sourcegraph_log//:log",
"@io_opentelemetry_go_otel//attribute",
],
Expand Down
108 changes: 34 additions & 74 deletions cmd/frontend/internal/auth/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"strings"
"time"

"github.com/boj/redistore"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/sourcegraph/log"
"go.opentelemetry.io/otel/attribute"

Expand All @@ -25,17 +28,11 @@ import (
"github.com/sourcegraph/sourcegraph/internal/trace"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"

"github.com/inconshreveable/log15" //nolint:logging // TODO move all logging to sourcegraph/log

"github.com/boj/redistore"
"github.com/gorilla/sessions"
)

const SignOutCookie = "sg-signout"

var (
sessionStore sessions.Store
sessionCookieKey = env.Get("SRC_SESSION_COOKIE_KEY", "", "secret key used for securing the session cookies")
)

Expand All @@ -54,11 +51,6 @@ type sessionInfo struct {
UserCreatedAt time.Time `json:"userCreatedAt"`
}

// SetSessionStore sets the backing store used for storing sessions on the server. It should be called exactly once.
func SetSessionStore(s sessions.Store) {
sessionStore = s
}

// sessionsStore wraps another sessions.Store to dynamically set the values
// of the session.Options.Secure and session.Options.SameSite fields to what
// is returned by the secure closure at invocation time.
Expand Down Expand Up @@ -90,25 +82,36 @@ func (st *sessionsStore) setSecureOptions(s *sessions.Session) {
}
}

// NewRedisStore creates a new session store backed by Redis.
func NewRedisStore(secureCookie func() bool) sessions.Store {
var store sessions.Store
var options *sessions.Options
var mockSessionStore sessions.Store

pool := redispool.Store.Pool()
rstore, err := redistore.NewRediStoreWithPool(pool, []byte(sessionCookieKey))
if err != nil {
waitForRedis(rstore)
// newSessionStore creates a new session store backed by Redis.
func newSessionStore() sessions.Store {
if mockSessionStore != nil {
return mockSessionStore
}

rstore := &redistore.RediStore{
Pool: redispool.Store.Pool(),
Codecs: securecookie.CodecsFromPairs([]byte(sessionCookieKey)),
Options: &sessions.Options{
Path: "/",
HttpOnly: true,
MaxAge: 86400 * 30, // 30 days, default of the library
},
DefaultMaxAge: 60 * 20, // 20 minutes seems like a reasonable default
}
rstore.SetMaxLength(4096)
rstore.SetSerializer(redistore.GobSerializer{})
rstore.SetKeyPrefix("session_")

secureCookie := func() bool {
return conf.ExternalURLParsed().Scheme == "https"
}
store = rstore
options = rstore.Options

options.Path = "/"
options.HttpOnly = true
setSessionSecureOptions(rstore.Options, secureCookie())

setSessionSecureOptions(options, secureCookie())
return &sessionsStore{
Store: store,
Store: rstore,
secure: secureCookie,
}
}
Expand All @@ -133,60 +136,13 @@ func setSessionSecureOptions(opts *sessions.Options, secure bool) {
opts.Secure = secure
}

// Ping attempts to contact Redis and returns a non-nil error upon failure. It is intended to be
// used by health checks.
func Ping() error {
if sessionStore == nil {
return errors.New("redis session store is not available")
}
rstore, ok := sessionStore.(*redistore.RediStore)
if !ok {
// Only try to ping Redis session stores. If we add other types of session stores, add ways
// to ping them here.
return nil
}
return ping(rstore)
}

func ping(s *redistore.RediStore) error {
conn := s.Pool.Get()
defer conn.Close()
data, err := conn.Do("PING")
if err != nil {
return err
}
if data != "PONG" {
return errors.New("no pong received")
}
return nil
}

// waitForRedis waits up to a certain timeout for Redis to become reachable, to reduce the
// likelihood of the HTTP handlers starting to serve requests while Redis (and therefore session
// data) is still unavailable. After the timeout has elapsed, if Redis is still unreachable, it
// continues anyway (because that's probably better than the site not coming up at all).
func waitForRedis(s *redistore.RediStore) {
const timeout = 5 * time.Second
deadline := time.Now().Add(timeout)
var err error
for {
time.Sleep(150 * time.Millisecond)
err = ping(s)
if err == nil {
return
}
if time.Now().After(deadline) {
log15.Warn("Redis (used for session store) failed to become reachable. Will continue trying to establish connection in background.", "timeout", timeout, "error", err)
return
}
}
}

// SetData sets the session data at the key. The session data is a map of keys to values. If no
// session exists, a new session is created.
//
// The value is JSON-encoded before being stored.
func SetData(w http.ResponseWriter, r *http.Request, key string, value any) error {
sessionStore := newSessionStore()

session, err := sessionStore.Get(r, cookieName)
if err != nil {
return errors.WithMessage(err, "getting session")
Expand All @@ -207,6 +163,8 @@ func SetData(w http.ResponseWriter, r *http.Request, key string, value any) erro
//
// The value is JSON-decoded from the raw bytes stored by the call to SetData.
func GetData(r *http.Request, key string, value any) error {
sessionStore := newSessionStore()

session, err := sessionStore.Get(r, cookieName)
if err != nil {
return errors.WithMessage(err, "getting session")
Expand Down Expand Up @@ -298,6 +256,8 @@ func deleteSession(w http.ResponseWriter, r *http.Request) error {
return nil // nothing to do
}

sessionStore := newSessionStore()

session, err := sessionStore.Get(r, cookieName)
session.Options.MaxAge = -1 // expire immediately
if err == nil {
Expand Down
29 changes: 10 additions & 19 deletions cmd/frontend/internal/auth/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import (
func TestSetActorDeleteSession(t *testing.T) {
logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

userCreatedAt := time.Now()

Expand Down Expand Up @@ -75,7 +74,7 @@ func TestSetActorDeleteSession(t *testing.T) {
}

// Check that actor exists in the session
session, err := sessionStore.Get(authedReq, cookieName)
session, err := newSessionStore().Get(authedReq, cookieName)
if err != nil {
t.Fatalf("didn't find session: %v", err)
}
Expand Down Expand Up @@ -148,8 +147,7 @@ func checkCookieDeleted(t *testing.T, resp *http.Response) {
func TestSessionExpiry(t *testing.T) {
logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

userCreatedAt := time.Now()

Expand Down Expand Up @@ -195,8 +193,7 @@ func TestSessionExpiry(t *testing.T) {
func TestManualSessionExpiry(t *testing.T) {
logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

user := &types.User{ID: 123, InvalidatedSessionsAt: time.Now()}
users := dbmocks.NewStrictMockUserStore()
Expand Down Expand Up @@ -237,8 +234,7 @@ func TestManualSessionExpiry(t *testing.T) {
}

func TestCookieMiddleware(t *testing.T) {
cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

actors := []*actor.Actor{{UID: 123, FromSessionCookie: true}, {UID: 456}, {UID: 789}}
userCreatedAt := time.Now()
Expand Down Expand Up @@ -325,8 +321,7 @@ func sessionCookie(r *http.Request) string {

func TestRecoverFromInvalidCookieValue(t *testing.T) {
logger := logtest.Scoped(t)
cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

// An actual cookie value that is an encoded JWT set by our old github.com/crewjam/saml-based
// SAML impl.
Expand Down Expand Up @@ -369,8 +364,7 @@ func TestRecoverFromInvalidCookieValue(t *testing.T) {
func TestMismatchedUserCreationFails(t *testing.T) {
logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

// The user's creation date is fixed in the database, which
// will be reflected in the session store after an authenticated
Expand Down Expand Up @@ -431,8 +425,7 @@ func TestMismatchedUserCreationFails(t *testing.T) {
func TestOldUserSessionSucceeds(t *testing.T) {
logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

// This user's session will _not_ have the UserCreatedAt value in the session
// store. When that situation occurs, we want to allow the session to continue
Expand Down Expand Up @@ -490,8 +483,7 @@ func TestExpiredLicenseOnlyAllowsAdmins(t *testing.T) {

logger := logtest.Scoped(t)

cleanup := ResetMockSessionStore(t)
defer cleanup()
ResetMockSessionStore(t)

userCreatedAt := time.Now()

Expand Down Expand Up @@ -569,8 +561,7 @@ func TestExpiredLicenseOnlyAllowsAdmins(t *testing.T) {
}

func TestSetActorFromUser(t *testing.T) {
cleanup := ResetMockSessionStore(t)
t.Cleanup(cleanup)
ResetMockSessionStore(t)

user := &types.User{
ID: 1,
Expand Down
10 changes: 5 additions & 5 deletions cmd/frontend/internal/auth/session/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (
"github.com/gorilla/sessions"
)

func ResetMockSessionStore(t *testing.T) (cleanup func()) {
func ResetMockSessionStore(t *testing.T) {
var err error
tempdir, err := os.MkdirTemp("", "sourcegraph-oidc-test")
if err != nil {
return func() {}
t.Fatal(err)
}

defer func() {
Expand All @@ -21,8 +21,8 @@ func ResetMockSessionStore(t *testing.T) (cleanup func()) {
}
}()

SetSessionStore(sessions.NewFilesystemStore(tempdir, securecookie.GenerateRandomKey(2048)))
return func() {
mockSessionStore = sessions.NewFilesystemStore(tempdir, securecookie.GenerateRandomKey(2048))
t.Cleanup(func() {
os.RemoveAll(tempdir)
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ func newMockDBAndRequester() mockDetails {
}

func TestMiddleware(t *testing.T) {
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
session.ResetMockSessionStore(t)

const testCode = "testCode"
providerConfig := cloud.SchemaAuthProviderSourcegraphOperator{
Expand Down
Loading

0 comments on commit 7a7c663

Please sign in to comment.