diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index d45315f..207cef5 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -114,7 +114,7 @@ jobs: wait-for: 90s - run: go install github.com/fullstorydev/grpcurl/cmd/grpcurl@v1.8.9 - run: grpcurl -plaintext localhost:9000 list - - run: grpcurl -plaintext localhost:9000 list policy.attributes.AttributesService + - run: grpcurl -plaintext localhost:9000 grpc.health.v1.Health.Check image: name: image build diff --git a/docs/configuration.md b/docs/configuration.md index 75fe365..e1afb6f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -42,6 +42,11 @@ The server configuration is used to define how the application runs its server. - `enabled`: Enable tls. `(default: false)` - `cert`: The path to the tls certificate. - `key`: The path to the tls key. +- `auth`: The configuration for your trusted IDP. + - `enabled`: Enable authentication. `(default: true)` + - `audience`: The audience for the IDP. + - `issuer`: The issuer for the IDP. + - `clients`: A list of client id's that are allowed Example: @@ -56,6 +61,13 @@ server: enabled: true cert: /path/to/cert key: /path/to/key + auth: + enabled: true + audience: https://example.com + issuer: https://example.com + clients: + - client_id + - client_id2 ``` ## Database Configuration diff --git a/example-opentdf.yaml b/example-opentdf.yaml index e4ef2d9..a6b1e38 100644 --- a/example-opentdf.yaml +++ b/example-opentdf.yaml @@ -38,6 +38,12 @@ services: - "msExchMailboxGuid" - "msExchMailboxSecurityDescriptor" server: + auth: + enabled: false + audience: "opentdf" + issuer: http://localhost:8888/auth/realms/opentdf + clients: + - "opentdf" grpc: port: 9000 reflectionEnabled: true # Default is false diff --git a/internal/auth/authn.go b/internal/auth/authn.go new file mode 100644 index 0000000..e579b6c --- /dev/null +++ b/internal/auth/authn.go @@ -0,0 +1,239 @@ +package auth + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "slices" + "strings" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +var ( + // Set of allowed gRPC endpoints that do not require authentication + allowedGRPCEndpoints = [...]string{ + "/grpc.health.v1.Health/Check", + "/wellknownconfiguration.WellKnownService/GetWellKnownConfiguration", + } + // Set of allowed HTTP endpoints that do not require authentication + allowedHTTPEndpoints = [...]string{ + "/healthz", + "/.well-known/opentdf-configuration", + } +) + +// Authentication holds a jwks cache and information about the openid configuration +type authentication struct { + // cache holds the jwks cache + cache *jwk.Cache + // openidConfigurations holds the openid configuration for each issuer + oidcConfigurations map[string]AuthNConfig +} + +// Creates new authN which is used to verify tokens for a set of given issuers +func NewAuthenticator(cfg AuthNConfig) (*authentication, error) { + a := &authentication{} + a.oidcConfigurations = make(map[string]AuthNConfig) + + ctx := context.Background() + + a.cache = jwk.NewCache(ctx) + + // Build new cache + // Discover OIDC Configuration + oidcConfig, err := DiscoverOIDCConfiguration(ctx, cfg.Issuer) + if err != nil { + return nil, err + } + + cfg.OIDCConfiguration = *oidcConfig + + // Register the jwks_uri with the cache + if err := a.cache.Register(cfg.JwksURI, jwk.WithMinRefreshInterval(15*time.Minute)); err != nil { + return nil, err + } + + // Need to refresh the cache to verify jwks is available + _, err = a.cache.Refresh(ctx, cfg.JwksURI) + if err != nil { + return nil, err + } + + a.oidcConfigurations[cfg.Issuer] = cfg + + return a, nil +} + +// verifyTokenHandler is a http handler that verifies the token +func (a authentication) VerifyTokenHandler(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if slices.Contains(allowedHTTPEndpoints[:], r.URL.Path) { + handler.ServeHTTP(w, r) + return + } + // Verify the token + header := r.Header["Authorization"] + if len(header) < 1 { + http.Error(w, "missing authorization header", http.StatusUnauthorized) + return + } + err := checkToken(r.Context(), header, a) + if err != nil { + slog.WarnContext(r.Context(), "failed to validate token", slog.String("error", err.Error())) + http.Error(w, "unauthenticated", http.StatusUnauthorized) + return + } + + handler.ServeHTTP(w, r) + }) +} + +// verifyTokenInterceptor is a grpc interceptor that verifies the token in the metadata +func (a authentication) VerifyTokenInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + // Allow health checks to pass through + if slices.Contains(allowedGRPCEndpoints[:], info.FullMethod) { + return handler(ctx, req) + } + + // Get the metadata from the context + // The keys within metadata.MD are normalized to lowercase. + // See: https://godoc.org/google.golang.org/grpc/metadata#New + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing metadata") + } + + // Verify the token + header := md["authorization"] + if len(header) < 1 { + return nil, status.Error(codes.Unauthenticated, "missing authorization header") + } + + err := checkToken(ctx, header, a) + if err != nil { + slog.Warn("failed to validate token", slog.String("error", err.Error())) + return nil, status.Errorf(codes.Unauthenticated, "unauthenticated") + } + + return handler(ctx, req) +} + +// checkToken is a helper function to verify the token. +func checkToken(ctx context.Context, authHeader []string, auth authentication) error { + var ( + tokenRaw string + tokenType string + ) + + // If we don't get a DPoP/Bearer token type, we can't proceed + switch { + case strings.HasPrefix(authHeader[0], "DPoP "): + tokenType = "DPoP" + tokenRaw = strings.TrimPrefix(authHeader[0], "DPoP ") + case strings.HasPrefix(authHeader[0], "Bearer "): + tokenType = "Bearer" + tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ") + default: + return fmt.Errorf("not of type bearer or dpop") + } + + // Future work is to validate DPoP proof if token type is DPoP + //nolint:staticcheck + if tokenType == "DPoP" { + // Implement in the future here or as separate interceptor + } + + // We have to get iss from the token first to verify the signature + unverifiedToken, err := jwt.Parse([]byte(tokenRaw), jwt.WithVerify(false)) + if err != nil { + return err + } + + // Get issuer from unverified token + issuer, exists := unverifiedToken.Get("iss") + if !exists { + return fmt.Errorf("missing issuer") + } + + // Get the openid configuration for the issuer + // Because we get an interface we need to cast it to a string + // and jwx expects it as a string so we should never hit this error if the token is valid + issuerStr, ok := issuer.(string) + if !ok { + return fmt.Errorf("invalid issuer") + } + oidc, exists := auth.oidcConfigurations[issuerStr] + if !exists { + return fmt.Errorf("invalid issuer") + } + + // Get key set from cache that matches the jwks_uri + keySet, err := auth.cache.Get(ctx, oidc.JwksURI) + if err != nil { + return fmt.Errorf("failed to get jwks from cache") + } + + // Now we verify the token signature + _, err = jwt.Parse([]byte(tokenRaw), + jwt.WithKeySet(keySet), + jwt.WithValidate(true), + jwt.WithIssuer(issuerStr), + jwt.WithAudience(oidc.Audience), + jwt.WithValidator(jwt.ValidatorFunc(auth.claimsValidator)), + ) + if err != nil { + return err + } + + return nil +} + +// claimsValidator is a custom validator to check extra claims in the token. +// right now it only checks for client_id +func (a authentication) claimsValidator(ctx context.Context, token jwt.Token) jwt.ValidationError { + var ( + clientID string + ) + + // Need to check for cid and client_id as this claim seems to be different between idp's + cidClaim, cidExists := token.Get("cid") + clientIDClaim, clientIDExists := token.Get("client_id") + + // Check to see if we have a client id claim + switch { + case cidExists: + if cid, ok := cidClaim.(string); ok { + clientID = cid + break + } + case clientIDExists: + if cid, ok := clientIDClaim.(string); ok { + clientID = cid + break + } + default: + return jwt.NewValidationError(fmt.Errorf("client id required")) + } + + // Check if the client id is allowed in list of clients + foundClientID := false + for _, c := range a.oidcConfigurations[token.Issuer()].Clients { + if c == clientID { + foundClientID = true + break + } + } + if !foundClientID { + return jwt.NewValidationError(fmt.Errorf("invalid client id")) + } + + return nil +} diff --git a/internal/auth/authn_test.go b/internal/auth/authn_test.go new file mode 100644 index 0000000..31366d7 --- /dev/null +++ b/internal/auth/authn_test.go @@ -0,0 +1,298 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type AuthSuite struct { + suite.Suite + server *httptest.Server + key jwk.Key + auth *authentication +} + +func (s *AuthSuite) SetupTest() { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + slog.Error("failed to generate RSA private key", slog.String("error", err.Error())) + return + } + + pubKeyJWK, err := jwk.FromRaw(privKey.PublicKey) + if err != nil { + slog.Error("failed to create jwk.Key from RSA public key", slog.String("error", err.Error())) + return + } + pubKeyJWK.Set(jws.KeyIDKey, "test") + pubKeyJWK.Set(jwk.AlgorithmKey, jwa.RS256) + + // Create a new set with rsa public key + set := jwk.NewSet() + if err := set.AddKey(pubKeyJWK); err != nil { + slog.Error("failed to add RSA public key to jwk.Set", slog.String("error", err.Error())) + return + } + + key, err := jwk.FromRaw(privKey) + if err != nil { + slog.Error("failed to create jwk.Key from RSA private key", slog.String("error", err.Error())) + return + } + key.Set(jws.KeyIDKey, "test") + + s.key = key + + s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if r.URL.Path == "/.well-known/openid-configuration" { + w.Write([]byte(fmt.Sprintf(`{"jwks_uri": "%s/jwks"}`, s.server.URL))) + return + } + if r.URL.Path == "/jwks" { + json.NewEncoder(w).Encode(set) + return + } + })) + + auth, err := NewAuthenticator(AuthNConfig{ + Issuer: s.server.URL, + Audience: "test", + Clients: []string{"client1", "client2", "client3"}, + }) + + assert.Nil(s.T(), err) + + s.auth = auth +} + +func (s *AuthSuite) TearDownTest() { + s.server.Close() +} + +func TestAuthSuite(t *testing.T) { + suite.Run(t, new(AuthSuite)) +} + +func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)) + + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "\"exp\" not satisfied", err.Error()) +} + +func (s *AuthSuite) Test_VerifyTokenHandler_When_Authorization_Header_Missing_Expect_Error() { + handler := s.auth.VerifyTokenHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + assert.Equal(s.T(), http.StatusUnauthorized, rec.Code) + assert.Equal(s.T(), "missing authorization header\n", rec.Body.String()) +} + +func (s *AuthSuite) Test_VerifyTokenInterceptor_When_Authorization_Header_Missing_Expect_Error() { + md := metadata.New(map[string]string{}) + ctx := metadata.NewIncomingContext(context.Background(), md) + _, err := s.auth.VerifyTokenInterceptor(ctx, "test", &grpc.UnaryServerInfo{ + FullMethod: "/test", + }, nil) + assert.NotNil(s.T(), err) + assert.ErrorIs(s.T(), err, status.Error(codes.Unauthenticated, "missing authorization header")) +} + +func (s *AuthSuite) Test_CheckToken_When_Authorization_Header_Invalid_Expect_Error() { + err := checkToken(context.Background(), []string{"BPOP "}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "not of type bearer or dpop", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Missing_Issuer_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "missing issuer", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Invalid_Issuer_Value_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", "invalid") + + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "invalid issuer", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Invalid_Issuer_INT_Value_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", 1) + + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "missing issuer", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Audience_Missing_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "claim \"aud\" not found", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Audience_Invalid_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "invalid") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "\"aud\" not satisfied", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_ClientID_Missing_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "client id required", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_ClientID_Invalid_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + tok.Set("client_id", "invalid") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "invalid client id", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_CID_Invalid_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + tok.Set("cid", "invalid") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "invalid client id", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_CID_Invalid_INT_Expect_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + tok.Set("cid", 1) + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.NotNil(s.T(), err) + assert.Equal(s.T(), "invalid client id", err.Error()) +} + +func (s *AuthSuite) Test_CheckToken_When_Valid_Expect_No_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + tok.Set("client_id", "client1") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, *s.auth) + assert.Nil(s.T(), err) +} + +func (s *AuthSuite) Test_CheckToken_When_Valid_CID_Expect_No_Error() { + tok := jwt.New() + tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set("iss", s.server.URL) + tok.Set("aud", "test") + tok.Set("cid", "client2") + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + + assert.NotNil(s.T(), signedTok) + assert.Nil(s.T(), err) + + err = checkToken(context.Background(), []string{fmt.Sprintf("DPoP %s", string(signedTok))}, *s.auth) + assert.Nil(s.T(), err) +} diff --git a/internal/auth/config.go b/internal/auth/config.go new file mode 100644 index 0000000..3205e74 --- /dev/null +++ b/internal/auth/config.go @@ -0,0 +1,15 @@ +package auth + +// AuthConfig pulls AuthN and AuthZ together +type Config struct { + Enabled bool `yaml:"enabled" default:"true" ` + AuthNConfig `mapstructure:",squash"` +} + +// AuthNConfig is the configuration need for the platform to validate tokens +type AuthNConfig struct { + Issuer string `yaml:"issuer" json:"issuer"` + Audience string `yaml:"audience" json:"audience"` + Clients []string `yaml:"clients" json:"clients"` + OIDCConfiguration `yaml:"-" json:"-"` +} diff --git a/internal/auth/discovery.go b/internal/auth/discovery.go new file mode 100644 index 0000000..c1dcdf8 --- /dev/null +++ b/internal/auth/discovery.go @@ -0,0 +1,52 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" +) + +const ( + // DiscoveryPath is the path to the discovery endpoint + DiscoveryPath = "/.well-known/openid-configuration" +) + +// OIDCConfiguration holds the openid configuration for the issuer. +// Currently only required fields are included (https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata) +type OIDCConfiguration struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JwksURI string `json:"jwks_uri"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + RequireRequestURIRegistration bool `json:"require_request_uri_registration"` +} + +// DiscoverOPENIDConfiguration discovers the openid configuration for the issuer provided +func DiscoverOIDCConfiguration(ctx context.Context, issuer string) (*OIDCConfiguration, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", issuer, DiscoveryPath), nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to discover idp: %s", resp.Status) + } + defer resp.Body.Close() + + cfg := &OIDCConfiguration{} + err = json.NewDecoder(resp.Body).Decode(&cfg) + if err != nil { + return nil, err + } + + return cfg, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 15b6121..42f4adc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,25 +12,16 @@ import ( "github.com/opentdf/platform/internal/logger" "github.com/opentdf/platform/internal/opa" "github.com/opentdf/platform/internal/server" + "github.com/opentdf/platform/pkg/serviceregistry" "github.com/spf13/viper" ) -type ServiceConfig struct { - Enabled bool `yaml:"enabled"` - Remote RemoteServiceConfig `yaml:"remote"` - ExtraProps map[string]interface{} `json:"-"` -} - -type RemoteServiceConfig struct { - Endpoint string `yaml:"endpoint"` -} - type Config struct { - DB db.Config `yaml:"db"` - OPA opa.Config `yaml:"opa"` - Server server.Config `yaml:"server"` - Logger logger.Config `yaml:"logger"` - Services map[string]ServiceConfig `yaml:"services" default:"{\"policy\": {\"enabled\": true}, \"health\": {\"enabled\": true}, \"authorization\": {\"enabled\": true}, \"wellknown\": {\"enabled\": true}}"` + DB db.Config `yaml:"db"` + OPA opa.Config `yaml:"opa"` + Server server.Config `yaml:"server"` + Logger logger.Config `yaml:"logger"` + Services map[string]serviceregistry.ServiceConfig `yaml:"services" default:"{\"policy\": {\"enabled\": true}, \"health\": {\"enabled\": true}, \"authorization\": {\"enabled\": true}, \"wellknown\": {\"enabled\": true}}"` } type Error string diff --git a/internal/server/server.go b/internal/server/server.go index a021cab..4c6bde3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "github.com/go-chi/cors" protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/opentdf/platform/internal/auth" "github.com/valyala/fasthttp/fasthttputil" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -36,9 +37,11 @@ func (e Error) Error() string { } type Config struct { - Grpc GrpcConfig `yaml:"grpc"` - HTTP HTTPConfig `yaml:"http"` - TLS TLSConfig `yaml:"tls"` + Grpc GrpcConfig `yaml:"grpc"` + HTTP HTTPConfig `yaml:"http"` + TLS TLSConfig `yaml:"tls"` + Auth auth.Config `yaml:"auth"` + WellKnownConfigRegister func(namespace string, config any) error } type GrpcConfig struct { @@ -100,12 +103,10 @@ func NewOpenTDFServer(config Config) (*OpenTDFServer, error) { grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(tlsConfig))) } - grpcOpts = append(grpcOpts, grpc.UnaryInterceptor( - protovalidate_middleware.UnaryServerInterceptor(validator), - )) - - grpcServer := grpc.NewServer( - grpcOpts..., + // Build interceptor chain and handler chain + var ( + interceptors []grpc.UnaryServerInterceptor + handler http.Handler ) grpcInprocess := &inProcessServer{ @@ -117,6 +118,46 @@ func NewOpenTDFServer(config Config) (*OpenTDFServer, error) { runtime.WithHealthzEndpoint(healthpb.NewHealthClient(grpcInprocess.Conn())), ) + handler = mux + + // Add authN interceptor + if config.Auth.Enabled { + authN, err := auth.NewAuthenticator(config.Auth.AuthNConfig) + if err != nil { + return nil, fmt.Errorf("failed to create authentication interceptor: %w", err) + } + + interceptors = append(interceptors, authN.VerifyTokenInterceptor) + handler = authN.VerifyTokenHandler(mux) + + // Try an register oidc issuer to wellknown service but don't return an error if it fails + if err := config.WellKnownConfigRegister("platform_issuer", config.Auth.Issuer); err != nil { + slog.Warn("failed to register platform issuer", slog.String("error", err.Error())) + } + } + + // Add proto validation interceptor + interceptors = append(interceptors, protovalidate_middleware.UnaryServerInterceptor(validator)) + + // Add CORS + // TODO(#305) We need to make cors configurable + handler = cors.New(cors.Options{ + AllowOriginFunc: func(r *http.Request, origin string) bool { return true }, + AllowedMethods: []string{"GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"ACCEPT", "Authorization", "Content-Type", "X-CSRF-Token"}, + ExposedHeaders: []string{"Link"}, + AllowCredentials: true, + MaxAge: maxAge, + }).Handler(handler) + + grpcOpts = append(grpcOpts, grpc.ChainUnaryInterceptor( + interceptors..., + )) + + grpcServer := grpc.NewServer( + grpcOpts..., + ) + // Enable grpc reflection if config.Grpc.ReflectionEnabled { reflection.Register(grpcServer) @@ -128,14 +169,7 @@ func NewOpenTDFServer(config Config) (*OpenTDFServer, error) { WriteTimeout: writeTimeoutSeconds * time.Second, ReadTimeout: readTimeoutSeconds * time.Second, // We need to make cors configurable - Handler: cors.New(cors.Options{ - AllowOriginFunc: func(r *http.Request, origin string) bool { return true }, - AllowedMethods: []string{"GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"ACCEPT", "Authorization", "Content-Type", "X-CSRF-Token"}, - ExposedHeaders: []string{"Link"}, - AllowCredentials: true, - MaxAge: maxAge, - }).Handler(mux), + Handler: handler, TLSConfig: tlsConfig, } } diff --git a/pkg/server/start.go b/pkg/server/start.go index 067d9de..e63eaf0 100644 --- a/pkg/server/start.go +++ b/pkg/server/start.go @@ -77,6 +77,7 @@ func Start(f ...StartOptions) error { defer dbClient.Close() // Create new server for grpc & http. Also will support in process grpc potentially too + conf.Server.WellKnownConfigRegister = wellknown.RegisterConfiguration otdf, err := server.NewOpenTDFServer(conf.Server) if err != nil { slog.Error("issue creating opentdf server", slog.String("error", err.Error())) diff --git a/pkg/server/start_test.go b/pkg/server/start_test.go index 6a75433..fd35f7a 100644 --- a/pkg/server/start_test.go +++ b/pkg/server/start_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/opentdf/platform/internal/auth" "github.com/opentdf/platform/internal/config" "github.com/opentdf/platform/internal/server" "github.com/opentdf/platform/pkg/serviceregistry" @@ -48,6 +49,9 @@ func ServiceRegistrationTest() serviceregistry.Registration { func Test_Start_When_Extra_Service_Registered_Expect_Response(t *testing.T) { // Create new opentdf server s, err := server.NewOpenTDFServer(server.Config{ + Auth: auth.Config{ + Enabled: false, + }, Grpc: server.GrpcConfig{ Port: 43482, }, @@ -64,7 +68,7 @@ func Test_Start_When_Extra_Service_Registered_Expect_Response(t *testing.T) { // Start services with test service err = startServices(config.Config{ - Services: map[string]config.ServiceConfig{ + Services: map[string]serviceregistry.ServiceConfig{ "test": { Enabled: true, }, diff --git a/pkg/serviceregistry/serviceregistry.go b/pkg/serviceregistry/serviceregistry.go index 72b71cf..7a2ead9 100644 --- a/pkg/serviceregistry/serviceregistry.go +++ b/pkg/serviceregistry/serviceregistry.go @@ -8,15 +8,24 @@ import ( "github.com/opentdf/platform/sdk" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/opentdf/platform/internal/config" "github.com/opentdf/platform/internal/db" "github.com/opentdf/platform/internal/opa" "github.com/opentdf/platform/internal/server" "google.golang.org/grpc" ) +type ServiceConfig struct { + Enabled bool `yaml:"enabled"` + Remote RemoteServiceConfig `yaml:"remote"` + ExtraProps map[string]interface{} `json:"-"` +} + +type RemoteServiceConfig struct { + Endpoint string `yaml:"endpoint"` +} + type RegistrationParams struct { - Config config.ServiceConfig + Config ServiceConfig OTDF *server.OpenTDFServer DBClient *db.Client Engine *opa.Engine diff --git a/services/wellknownconfiguration/wellknown_configuration.go b/services/wellknownconfiguration/wellknown_configuration.go index 2d818de..3da6393 100644 --- a/services/wellknownconfiguration/wellknown_configuration.go +++ b/services/wellknownconfiguration/wellknown_configuration.go @@ -3,6 +3,7 @@ package wellknownconfiguration import ( "context" "fmt" + "log/slog" "sync" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -49,6 +50,7 @@ func (s WellKnownService) GetWellKnownConfiguration(context.Context, *wellknown. cfg, err := structpb.NewStruct(wellKnownConfiguration) rwMutex.RUnlock() if err != nil { + slog.Error("failed to create struct for wellknown configuration", slog.String("error", err.Error())) return nil, status.Error(codes.Internal, "failed to create struct for wellknown configuration") } return &wellknown.GetWellKnownConfigurationResponse{