Skip to content

Commit

Permalink
commit 21
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnSPG committed Dec 16, 2023
1 parent d0b06aa commit ceee46d
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 12 deletions.
3 changes: 3 additions & 0 deletions app/services/sales-api/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"

testgrp "github.com/shawnzxx/service/app/services/sales-api/handlers/v1"
"github.com/shawnzxx/service/business/web/auth"
"github.com/shawnzxx/service/business/web/v1/mid"
"github.com/shawnzxx/service/foundation/web"
"go.uber.org/zap"
Expand All @@ -14,6 +15,7 @@ import (
type APIMuxConfig struct {
Shutdown chan os.Signal
Log *zap.SugaredLogger
Auth *auth.Auth
}

// APIMux constructs a http.Handler with all application routes defined.
Expand All @@ -23,6 +25,7 @@ func APIMux(cfg APIMuxConfig) *web.App {
app := web.NewApp(cfg.Shutdown, mid.Logger(cfg.Log), mid.Errors(cfg.Log), mid.Metrics(), mid.Panics())

app.Handle(http.MethodGet, "/test", testgrp.Test)
app.Handle(http.MethodGet, "/test/auth", testgrp.Test, mid.Authenticate(cfg.Auth), mid.Authorize(cfg.Auth, auth.RuleAdminOnly))

return app
}
29 changes: 29 additions & 0 deletions app/services/sales-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (

"github.com/ardanlabs/conf/v3"
"github.com/shawnzxx/service/app/services/sales-api/handlers"
"github.com/shawnzxx/service/business/web/auth"
"github.com/shawnzxx/service/business/web/v1/debug"
"github.com/shawnzxx/service/foundation/keystore"
"github.com/shawnzxx/service/foundation/logger"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -54,6 +56,11 @@ func run(log *zap.SugaredLogger) error {
APIHost string `conf:"default:0.0.0.0:3000"`
DebugHost string `conf:"default:0.0.0.0:4000"`
}
Auth struct {
KeysFolder string `conf:"default:zarf/keys/"`
ActiveKID string `conf:"default:54bb2165-71e1-41a6-af3e-7da4a0e1e2c1"`
Issuer string `conf:"default:service project"`
}
}{
Version: conf.Version{
Build: build,
Expand Down Expand Up @@ -83,6 +90,27 @@ func run(log *zap.SugaredLogger) error {
}
log.Infow("startup", "config", out)

// -------------------------------------------------------------------------
// Initialize authentication support

log.Infow("startup", "status", "initializing authentication support")

// Simple keystore versus using Vault.
ks, err := keystore.NewFS(os.DirFS(cfg.Auth.KeysFolder))
if err != nil {
return fmt.Errorf("reading keys: %w", err)
}

authCfg := auth.Config{
Log: log,
KeyLookup: ks,
}

authCong, err := auth.New(authCfg)
if err != nil {
return fmt.Errorf("constructing authCong: %w", err)
}

// -------------------------------------------------------------------------
// Start Debug Service

Expand All @@ -105,6 +133,7 @@ func run(log *zap.SugaredLogger) error {
apiMux := handlers.APIMux(handlers.APIMuxConfig{
Shutdown: shutdown,
Log: log,
Auth: authCong,
})

server := http.Server{
Expand Down
7 changes: 4 additions & 3 deletions business/core/user/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type Role struct {
}

// ParseRole parses the string value and returns a role if one exists.
// used by application layer, should not use by business layer, business layer should use Role type directly
func ParseRole(value string) (Role, error) {
role, exists := roles[value]
if !exists {
Expand All @@ -30,8 +31,8 @@ func ParseRole(value string) (Role, error) {
}

// MustParseRole parses the string value and returns a role if one exists.
// this function is only used for unit test
// If an error occurs the function panics.
// nobody should use Myst function in application layer, this function is only used for unit test
// because we have panic in this function, so we can not use this function in any layer except test
func MustParseRole(value string) Role {
role, err := ParseRole(value)
if err != nil {
Expand All @@ -57,7 +58,7 @@ func (r Role) MarshalText() ([]byte, error) {
return []byte(r.name), nil
}

// Equal provides support for the go-cmp package and testing.
// Equal provides support for the go-compare package use for testing.
func (r Role) Equal(r2 Role) bool {
return r.name == r2.name
}
24 changes: 16 additions & 8 deletions business/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ type Claims struct {
Roles []user.Role `json:"roles"`
}

// KeyLookup declares a method set of behavior for looking up
// private and public keys for JWT use. Since we are not sure is PEM format or other JWT format,
// here we named is key string, generic for return could be a PEM encoded string or a JWS based key.
// KeyLookup declares a method set of behavior for looking up private and public keys for JWT use.
// here we named is as key instead of pem, why? because we are using OPA which needs pem format for the public key,
// if in future we are not use OPA, for example jwt format, then we can rename it
type KeyLookup interface {
PrivateKey(kid string) (key string, err error)
PublicKey(kid string) (key string, err error)
Expand Down Expand Up @@ -68,19 +68,23 @@ func New(cfg Config) (*Auth, error) {

// GenerateToken generates a signed JWT token string representing the user Claims.
func (a *Auth) GenerateToken(kid string, claims Claims) (string, error) {
// if you want to generate token, you give kid you give claims
token := jwt.NewWithClaims(a.method, claims)
// put kid in to the token
token.Header["kid"] = kid

// do the private key lookup from that kid, may be background hit the Vault or something else, then get the private key
privateKeyPEM, err := a.keyLookup.PrivateKey(kid)
if err != nil {
return "", fmt.Errorf("private key: %w", err)
}

// convert to pem format
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
if err != nil {
return "", fmt.Errorf("parsing private pem: %w", err)
}

// sign the token
str, err := token.SignedString(privateKey)
if err != nil {
return "", fmt.Errorf("signing token: %w", err)
Expand All @@ -89,21 +93,22 @@ func (a *Auth) GenerateToken(kid string, claims Claims) (string, error) {
return str, nil
}

// Authenticate processes the token to validate the sender's token is valid.
// Authenticate processes to validate the token's signature.
func (a *Auth) Authenticate(ctx context.Context, bearerToken string) (Claims, error) {
parts := strings.Split(bearerToken, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return Claims{}, errors.New("expected authorization header format: Bearer <token>")
}

var claims Claims
// parsing token back to claim, here we don't worry token if is valid ot not
// for valid token we use OPA in later steps.
token, _, err := a.parser.ParseUnverified(parts[1], &claims)
if err != nil {
return Claims{}, fmt.Errorf("error parsing token: %w", err)
}

// Perform an extra level of authentication verification with OPA.

kidRaw, exists := token.Header["kid"]
if !exists {
return Claims{}, fmt.Errorf("kid missing from header: %w", err)
Expand All @@ -114,11 +119,13 @@ func (a *Auth) Authenticate(ctx context.Context, bearerToken string) (Claims, er
return Claims{}, fmt.Errorf("kid malformed: %w", err)
}

// use keyStore find back the public key
pem, err := a.publicKeyLookup(kid)
if err != nil {
return Claims{}, fmt.Errorf("failed to fetch public key: %w", err)
}

// prepare input struct for opa to validate the token
input := map[string]any{
"Key": pem,
"Token": parts[1],
Expand All @@ -129,8 +136,9 @@ func (a *Auth) Authenticate(ctx context.Context, bearerToken string) (Claims, er
return Claims{}, fmt.Errorf("authentication failed : %w", err)
}

// Check the database for this user to verify they are still enabled.

// TODO additional check for authentication
// check our the database for this user still enabled.

return claims, nil
}

Expand Down
27 changes: 27 additions & 0 deletions business/web/auth/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package auth

import (
"context"
)

// ctxKey represents the type of value for the context key.
type ctxKey int

// key is used to store/retrieve a Claims value from a context.Context.
const claimKey ctxKey = 1

// =============================================================================

// SetClaims stores the claims in the context.
func SetClaims(ctx context.Context, claims Claims) context.Context {
return context.WithValue(ctx, claimKey, claims)
}

// GetClaims returns the claims from the context.
func GetClaims(ctx context.Context) Claims {
v, ok := ctx.Value(claimKey).(Claims)
if !ok {
return Claims{}
}
return v
}
31 changes: 31 additions & 0 deletions business/web/auth/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package auth

import (
"errors"
"fmt"
)

// AuthError is used to pass an error during the request through the
// application with auth specific context.
type AuthError struct {
msg string
}

// NewAuthError creates an AuthError for the provided message.
func NewAuthError(format string, args ...any) error {
return &AuthError{
msg: fmt.Sprintf(format, args...),
}
}

// Error implements the error interface. It uses the default message of the
// wrapped error. This is what will be shown in the services' logs.
func (ae *AuthError) Error() string {
return ae.msg
}

// IsAuthError checks if an error of type AuthError exists.
func IsAuthError(err error) bool {
var ae *AuthError
return errors.As(err, &ae)
}
52 changes: 52 additions & 0 deletions business/web/v1/mid/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package mid

import (
"context"
"net/http"

"github.com/shawnzxx/service/business/web/auth"
"github.com/shawnzxx/service/foundation/web"
)

// Authenticate validates a JWT from the `Authorization` header.
func Authenticate(a *auth.Auth) web.Middleware {
m := func(handler web.Handler) web.Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
claims, err := a.Authenticate(ctx, r.Header.Get("authorization"))
if err != nil {
return auth.NewAuthError("authenticate: failed: %s", err)
}

ctx = auth.SetClaims(ctx, claims)

return handler(ctx, w, r)
}

return h
}

return m
}

// Authorize validates that an authenticated user has at least one role from a
// specified list. This method constructs the actual function that is used.
func Authorize(a *auth.Auth, rule string) web.Middleware {
m := func(handler web.Handler) web.Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
claims := auth.GetClaims(ctx)
if claims.Subject == "" {
return auth.NewAuthError("authorize: you are not authorized for that action, no claims")
}

if err := a.Authorize(ctx, claims, rule); err != nil {
return auth.NewAuthError("authorize: you are not authorized for that action, claims[%v] rule[%v]: %s", claims.Roles, rule, err)
}

return handler(ctx, w, r)
}

return h
}

return m
}
7 changes: 7 additions & 0 deletions business/web/v1/mid/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"

"github.com/shawnzxx/service/business/web/auth"
v1 "github.com/shawnzxx/service/business/web/v1"
"github.com/shawnzxx/service/foundation/web"
"go.uber.org/zap"
Expand All @@ -29,6 +30,12 @@ func Errors(log *zap.SugaredLogger) web.Middleware {
}
status = reqErr.Status

case auth.IsAuthError(err):
er = v1.ErrorResponse{
Error: http.StatusText(http.StatusUnauthorized),
}
status = http.StatusUnauthorized

default:
er = v1.ErrorResponse{
Error: http.StatusText(http.StatusInternalServerError),
Expand Down
2 changes: 1 addition & 1 deletion business/web/v1/mid/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"go.uber.org/zap"
)

// this logger at business layer, becuase as top app layer's middleware it can log request and response in general
// Logger at business layer, becuase as top app layer's middleware it can log request and response in general
// later business layer logic also can use this logger to log business realted logs
func Logger(log *zap.SugaredLogger) web.Middleware {
m := func(handler web.Handler) web.Handler {
Expand Down
Loading

0 comments on commit ceee46d

Please sign in to comment.