Skip to content

Commit

Permalink
refactor: move hooks from mfa.go to hooks.go (supabase#1373)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Moves the existing `hooks.go` logic to the `hooks` package.

Co-authored-by: joel@joellee.org <joel@joellee.org>
  • Loading branch information
J0 and joel@joellee.org authored Jan 12, 2024
1 parent 41aac69 commit 2d64099
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 138 deletions.
138 changes: 138 additions & 0 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -17,6 +18,7 @@ import (
jwt "github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/hooks"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -287,3 +289,139 @@ type connectionWatcher struct {
func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) {
c.gotConn = true
}

func (a *API) runHook(ctx context.Context, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

request, err := json.Marshal(input)
if err != nil {
panic(err)
}

var response []byte
if err := db.Transaction(func(tx *storage.Connection) error {
// We rely on Postgres timeouts to ensure the function doesn't overrun
if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
return terr
}

if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil {
return terr
}

// reset the timeout
if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
return terr
}

return nil
}); err != nil {
return nil, err
}

if err := json.Unmarshal(response, output); err != nil {
return response, err
}

return response, nil
}

func (a *API) invokeHook(ctx context.Context, input, output any) error {
config := a.config
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.PasswordVerificationAttemptInput:
hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.CustomAccessTokenInput:
hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
if !ok {
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runHook(ctx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}
if err := validateTokenClaims(hookOutput.Claims); err != nil {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: err.Error(),
}

return httpError
}
return nil

default:
panic("unknown hook input type")
}
}
138 changes: 0 additions & 138 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"

Expand Down Expand Up @@ -198,142 +196,6 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
})
}

func (a *API) runHook(ctx context.Context, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

request, err := json.Marshal(input)
if err != nil {
panic(err)
}

var response []byte
if err := db.Transaction(func(tx *storage.Connection) error {
// We rely on Postgres timeouts to ensure the function doesn't overrun
if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
return terr
}

if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil {
return terr
}

// reset the timeout
if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
return terr
}

return nil
}); err != nil {
return nil, err
}

if err := json.Unmarshal(response, output); err != nil {
return response, err
}

return response, nil
}

func (a *API) invokeHook(ctx context.Context, input, output any) error {
config := a.config
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.PasswordVerificationAttemptInput:
hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.CustomAccessTokenInput:
hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
if !ok {
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runHook(ctx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}
if err := validateTokenClaims(hookOutput.Claims); err != nil {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: err.Error(),
}

return httpError
}
return nil

default:
panic("unknown hook input type")
}
}

func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
var err error
ctx := r.Context()
Expand Down

0 comments on commit 2d64099

Please sign in to comment.