Skip to content

Commit

Permalink
fix: allow gotrue to work with multiple custom domains (#999)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Improves on #725, albeit with a slightly different approach
* Gotrue will accept an allow list of domains via a comma-separate
string (`DOMAIN_ALLOW_LIST`) , which includes the `API_EXTERNAL_URL` by
default. On each request, gotrue will check that the domain being used
is also included in the allow list.
* When gotrue starts up, it will take the `DOMAIN_ALLOW_LIST` and
convert it into a map where the key is the hostname and the value is the
url
* When a request is made to gotrue, gotrue will check the
`DomainAllowListMap` to check if there is a matching hostname before
allowing the request through. If there isn't a matching hostname used,
gotrue will default to use the `API_EXTERNAL_URL` instead.
* This helps to make gotrue usable with multiple custom domains, and
also allows the email links to contain the custom domain.
* Since the `EXTERNAL_XXX_REDIRECT_URI` is derived during runtime, we
can remove that config once this PR is merged in as long as the
`REDIRECT_URI` is also included in the `DOMAIN_ALLOW_LIST`

---------

Co-authored-by: Joel Lee <lee.yi.jie.joel@gmail.com>
  • Loading branch information
kangmingtay and J0 authored May 12, 2023
1 parent bafb89b commit 91a82ed
Show file tree
Hide file tree
Showing 18 changed files with 277 additions and 102 deletions.
3 changes: 2 additions & 1 deletion internal/api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() {
req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

*ts.Config = *c.customConfig
ts.Config.JWT = c.customConfig.JWT
ts.Config.External = c.customConfig.External
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expected, w.Code)
})
Expand Down
2 changes: 2 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati

r.Route("/callback", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)
r.Use(api.loadFlowState)

r.Get("/", api.ExternalProviderCallback)
Expand All @@ -97,6 +98,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati

r.Route("/", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)

r.Get("/settings", api.Settings)

Expand Down
14 changes: 14 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"net/url"

jwt "github.com/golang-jwt/jwt"
"github.com/supabase/gotrue/internal/models"
Expand All @@ -28,6 +29,7 @@ const (
oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
oauthVerifierKey = contextKey("oauth_verifier")
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
)

Expand Down Expand Up @@ -235,3 +237,15 @@ func getSSOProvider(ctx context.Context) *models.SSOProvider {
}
return obj.(*models.SSOProvider)
}

func withExternalHost(ctx context.Context, u *url.URL) context.Context {
return context.WithValue(ctx, externalHostKey, u)
}

func getExternalHost(ctx context.Context) *url.URL {
obj := ctx.Value(externalHostKey)
if obj == nil {
return nil
}
return obj.(*url.URL)
}
21 changes: 20 additions & 1 deletion internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
if !emailData.Verified && !config.Mailer.Autoconfirm {
mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
externalURL := getExternalHost(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
if errors.Is(terr, MaxFrequencyLimitError) {
return nil, tooManyRequestsError("For security purposes, you can only request this once every minute")
}
Expand Down Expand Up @@ -510,43 +511,61 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) {
config := a.config
name = strings.ToLower(name)
callbackURL := getExternalHost(ctx).String() + "/callback"

switch name {
case "apple":
config.External.Apple.RedirectURI = callbackURL
return provider.NewAppleProvider(config.External.Apple)
case "azure":
config.External.Azure.RedirectURI = callbackURL
return provider.NewAzureProvider(config.External.Azure, scopes)
case "bitbucket":
config.External.Bitbucket.RedirectURI = callbackURL
return provider.NewBitbucketProvider(config.External.Bitbucket)
case "discord":
config.External.Discord.RedirectURI = callbackURL
return provider.NewDiscordProvider(config.External.Discord, scopes)
case "github":
config.External.Github.RedirectURI = callbackURL
return provider.NewGithubProvider(config.External.Github, scopes)
case "gitlab":
config.External.Gitlab.RedirectURI = callbackURL
return provider.NewGitlabProvider(config.External.Gitlab, scopes)
case "google":
config.External.Google.RedirectURI = callbackURL
return provider.NewGoogleProvider(config.External.Google, scopes)
case "kakao":
return provider.NewKakaoProvider(config.External.Kakao, scopes)
case "keycloak":
config.External.Keycloak.RedirectURI = callbackURL
return provider.NewKeycloakProvider(config.External.Keycloak, scopes)
case "linkedin":
config.External.Linkedin.RedirectURI = callbackURL
return provider.NewLinkedinProvider(config.External.Linkedin, scopes)
case "facebook":
config.External.Facebook.RedirectURI = callbackURL
return provider.NewFacebookProvider(config.External.Facebook, scopes)
case "notion":
config.External.Notion.RedirectURI = callbackURL
return provider.NewNotionProvider(config.External.Notion)
case "spotify":
config.External.Spotify.RedirectURI = callbackURL
return provider.NewSpotifyProvider(config.External.Spotify, scopes)
case "slack":
config.External.Slack.RedirectURI = callbackURL
return provider.NewSlackProvider(config.External.Slack, scopes)
case "twitch":
config.External.Twitch.RedirectURI = callbackURL
return provider.NewTwitchProvider(config.External.Twitch, scopes)
case "twitter":
config.External.Twitter.RedirectURI = callbackURL
return provider.NewTwitterProvider(config.External.Twitter, scopes)
case "workos":
config.External.WorkOS.RedirectURI = callbackURL
return provider.NewWorkOSProvider(config.External.WorkOS)
case "zoom":
config.External.Zoom.RedirectURI = callbackURL
return provider.NewZoomProvider(config.External.Zoom)
default:
return nil, fmt.Errorf("Provider %s could not be found", name)
Expand Down
3 changes: 2 additions & 1 deletion internal/api/invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {

mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
if err := sendInvite(tx, user, mailer, referrer, config.Mailer.OtpLength); err != nil {
externalURL := getExternalHost(ctx)
if err := sendInvite(tx, user, mailer, referrer, externalURL, config.Mailer.OtpLength); err != nil {
return internalServerError("Error inviting user").WithInternalError(err)
}
return nil
Expand Down
3 changes: 2 additions & 1 deletion internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, flowType)
externalURL := getExternalHost(ctx)
return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
})
if err != nil {
if errors.Is(err, MaxFrequencyLimitError) {
Expand Down
24 changes: 13 additions & 11 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -205,7 +206,8 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error {
return terr
}

url, terr = mailer.GetEmailActionLink(user, params.Type, referrer)
externalURL := getExternalHost(ctx)
url, terr = mailer.GetEmailActionLink(user, params.Type, referrer, externalURL)
if terr != nil {
return terr
}
Expand All @@ -228,7 +230,7 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error {
return sendJSON(w, http.StatusOK, resp)
}

func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand All @@ -241,15 +243,15 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail
token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
u.ConfirmationToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.ConfirmationMail(u, otp, referrerURL); err != nil {
if err := mailer.ConfirmationMail(u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending confirmation email")
}
u.ConfirmationSentAt = &now
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
}

func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, otpLength int) error {
func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error {
var err error
oldToken := u.ConfirmationToken
otp, err := crypto.GenerateOtp(otpLength)
Expand All @@ -258,7 +260,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
}
u.ConfirmationToken = fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
now := time.Now()
if err := mailer.InviteMail(u, otp, referrerURL); err != nil {
if err := mailer.InviteMail(u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending invite email")
}
Expand All @@ -267,7 +269,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite")
}

func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand All @@ -281,7 +283,7 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile
token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
u.RecoveryToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.RecoveryMail(u, otp, referrerURL); err != nil {
if err := mailer.RecoveryMail(u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending recovery email")
}
Expand Down Expand Up @@ -313,7 +315,7 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma
return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication")
}

func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
// since Magic Link is just a recovery with a different template and behaviour
// around new users we will reuse the recovery db timer to prevent potential abuse
Expand All @@ -329,7 +331,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
u.RecoveryToken = addFlowPrefixToToken(token, flowType)

now := time.Now()
if err := mailer.MagicLinkMail(u, otp, referrerURL); err != nil {
if err := mailer.MagicLinkMail(u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending magic link email")
}
Expand All @@ -338,7 +340,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
}

// sendEmailChange sends out an email change token to the new email.
func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email string, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
if u.EmailChangeSentAt != nil && !u.EmailChangeSentAt.Add(config.SMTP.MaxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand Down Expand Up @@ -366,7 +368,7 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu

u.EmailChangeConfirmStatus = zeroConfirmation
now := time.Now()
if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL); err != nil {
if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL, externalURL); err != nil {
return err
}

Expand Down
16 changes: 15 additions & 1 deletion internal/api/mail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -39,6 +40,11 @@ func (ts *MailTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

ts.Config.Mailer.SecureEmailChangeEnabled = true

// Create User
u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating new user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user")
}

func (ts *MailTestSuite) TestGenerateLink() {
Expand Down Expand Up @@ -108,11 +114,14 @@ func (ts *MailTestSuite) TestGenerateLink() {
},
}

customDomainUrl, err := url.ParseRequestURI("https://example.gotrue.com")
require.NoError(ts.T(), err)

for _, c := range cases {
ts.Run(c.Desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.Body))
req := httptest.NewRequest(http.MethodPost, "/admin/generate_link", &buffer)
req := httptest.NewRequest(http.MethodPost, customDomainUrl.String()+"/admin/generate_link", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()

Expand All @@ -131,6 +140,11 @@ func (ts *MailTestSuite) TestGenerateLink() {

// check if hashed_token matches hash function of email and the raw otp
require.Equal(ts.T(), data["hashed_token"], fmt.Sprintf("%x", sha256.Sum224([]byte(c.Body.Email+data["email_otp"].(string)))))

// check if the host used in the email link matches the initial request host
u, err := url.ParseRequestURI(data["action_link"].(string))
require.NoError(ts.T(), err)
require.Equal(ts.T(), req.Host, u.Host)
})
}
}
28 changes: 28 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -180,6 +182,32 @@ func isIgnoreCaptchaRoute(req *http.Request) bool {
return false
}

func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
config := a.config

var u *url.URL
var err error

baseUrl := config.API.ExternalURL
xForwardedHost := req.Header.Get("X-Forwarded-Host")
xForwardedProto := req.Header.Get("X-Forwarded-Proto")
if xForwardedHost != "" && xForwardedProto != "" {
baseUrl = fmt.Sprintf("%s://%s", xForwardedProto, xForwardedHost)
} else if req.URL.Scheme != "" && req.URL.Hostname() != "" {
baseUrl = fmt.Sprintf("%s://%s", req.URL.Scheme, req.URL.Hostname())
}
if u, err = url.ParseRequestURI(baseUrl); err != nil {
// fallback to the default hostname
log := observability.GetLogEntry(req)
log.WithField("request_url", baseUrl).Warn(err)
if u, err = url.ParseRequestURI(config.API.ExternalURL); err != nil {
return ctx, err
}
}
return withExternalHost(ctx, u), nil
}

func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
if !a.config.SAML.Enabled {
Expand Down
30 changes: 30 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"

jwt "github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -229,6 +230,35 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
}
}

func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
cases := []struct {
desc string
requestURL string
expectedURL string
}{
{
desc: "Valid custom external url",
requestURL: "https://example.custom.com",
expectedURL: "https://example.custom.com",
},
}

_, err := url.ParseRequestURI("https://example.custom.com")
require.NoError(ts.T(), err)

for _, c := range cases {
ts.Run(c.desc, func() {
req := httptest.NewRequest(http.MethodPost, c.requestURL, nil)
w := httptest.NewRecorder()
ctx, err := ts.API.isValidExternalHost(w, req)
require.NoError(ts.T(), err)

externalURL := getExternalHost(ctx)
require.Equal(ts.T(), c.expectedURL, externalURL.String())
})
}
}

func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() {
cases := []struct {
desc string
Expand Down
Loading

0 comments on commit 91a82ed

Please sign in to comment.