Skip to content

Commit

Permalink
chore: move api token auth state management to backend using cookies (#…
Browse files Browse the repository at this point in the history
…343)

Co-authored-by: UncleGedd <42304551+UncleGedd@users.noreply.github.com>
  • Loading branch information
decleaver and UncleGedd authored Sep 19, 2024
1 parent 38df19a commit 45a4e76
Show file tree
Hide file tree
Showing 21 changed files with 233 additions and 250 deletions.
115 changes: 115 additions & 0 deletions pkg/api/auth/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2024-Present The UDS Authors

package auth

import (
"crypto/rand"
"encoding/hex"
"net/http"
"sync"
)

type InMemoryStorage struct {
sessionID string
mutex sync.RWMutex
}

func NewInMemoryStorage() *InMemoryStorage {
return &InMemoryStorage{}
}

func (s *InMemoryStorage) StoreSession(sessionID string) {
s.mutex.Lock()
defer s.mutex.Unlock()

// Replace the old session with the new one
s.sessionID = sessionID
}

func (s *InMemoryStorage) ValidateSession(sessionID string) bool {
s.mutex.RLock()
defer s.mutex.RUnlock()

// Check if the provided sessionID matches the stored session
return s.sessionID == sessionID
}

func (s *InMemoryStorage) RemoveSession() {
s.mutex.Lock()
defer s.mutex.Unlock()

// Clear the session
s.sessionID = ""
}

var storage = NewInMemoryStorage()

// TokenAuthenticator ensures the request has a valid token.
func TokenAuthenticator(validToken string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
ValidateSessionCookie(next, w, r)
} else if token != validToken {
// If a token is passed in and its not valid, return unauthorized
w.WriteHeader(http.StatusUnauthorized)
return
} else {
// If a token is passed in and its valid, set the session ID and continue
if token != "" && token == validToken {
sessionID := generateSessionID()
storage.StoreSession(sessionID)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})

next.ServeHTTP(w, r)
}
}
})
}
}

func ValidateSessionCookie(next http.Handler, w http.ResponseWriter, r *http.Request) {
// Retrieve the session cookie
cookie, err := r.Cookie("session_id")
if err != nil || !storage.ValidateSession(cookie.Value) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
}

func generateSessionID() string {
bytes := make([]byte, 16) // 16 bytes = 128 bits
if _, err := rand.Read(bytes); err != nil {
// Handle error
return ""
}
return hex.EncodeToString(bytes)
}

// Connect is a head-only request to test the connection.
func Connect(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}

func RequireJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")

if token == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
})
}
42 changes: 42 additions & 0 deletions pkg/api/auth/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2024-Present The UDS Authors

package auth

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestStoreSession(t *testing.T) {
storage := NewInMemoryStorage()
sessionID := "test-session-id"

storage.StoreSession(sessionID)

require.Equal(t, sessionID, storage.sessionID, "expected sessionID to be stored correctly")
}

func TestValidateSession(t *testing.T) {
storage := NewInMemoryStorage()
sessionID := "test-session-id"

storage.StoreSession(sessionID)

require.True(t, storage.ValidateSession(sessionID), "expected sessionID to be valid")

invalidSessionID := "invalid-session-id"
require.False(t, storage.ValidateSession(invalidSessionID), "expected invalid sessionID to be invalid")
}

func TestRemoveSession(t *testing.T) {
storage := NewInMemoryStorage()
sessionID := "test-session-id"

storage.StoreSession(sessionID)
storage.RemoveSession()

require.Empty(t, storage.sessionID, "expected sessionID to be empty after removal")
require.False(t, storage.ValidateSession(sessionID), "expected sessionID to be invalid after removal")
}
42 changes: 0 additions & 42 deletions pkg/api/auth/token.go

This file was deleted.

16 changes: 16 additions & 0 deletions pkg/api/middleware/api_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2024-Present The UDS Authors

package middleware

import (
"net/http"

"github.com/defenseunicorns/uds-runtime/pkg/api/auth"
)

func ValidateSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth.ValidateSessionCookie(next, w, r)
})
}
45 changes: 19 additions & 26 deletions pkg/api/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import (

"strings"

"encoding/json"

"github.com/defenseunicorns/pkg/exec"
"github.com/defenseunicorns/uds-runtime/pkg/api/auth"
_ "github.com/defenseunicorns/uds-runtime/pkg/api/docs" //nolint:staticcheck
Expand Down Expand Up @@ -60,6 +58,18 @@ func Setup(assets *embed.FS) (*chi.Mux, error) {
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)

if authSVC {
r.Use(auth.RequireJWT)
}

// Middleware chain for api token authentication
apiAuthMiddleware := func(next http.Handler) http.Handler {
if apiAuth {
return udsMiddleware.ValidateSession(next)
}
return next
}

// Setup k8s resources
k8sResources, err := setupK8sResources()
if err != nil {
Expand Down Expand Up @@ -94,29 +104,24 @@ func Setup(assets *embed.FS) (*chi.Mux, error) {
http.Redirect(w, r, "/swagger/index.html", http.StatusMovedPermanently)
})
r.Get("/swagger/*", httpSwagger.WrapHandler)
// expose API_AUTH_DISABLED env var to frontend via endpoint
r.Get("/auth-status", serveAuthStatus)
r.Get("/health", checkHealth(k8sResources, disconnected))
r.Route("/api/v1", func(r chi.Router) {
// Require a valid token for API calls
if apiAuth {
// If api auth is enabled, require a valid token for all routes under /api/v1
r.Use(auth.RequireLocalToken(token))
// Endpoint to test if connected with auth
r.Head("/", auth.Connect)
// authenticate token
r.With(auth.TokenAuthenticator(token)).Head("/api-auth", func(_ http.ResponseWriter, _ *http.Request) {})
} else {
r.Head("/api-auth", func(_ http.ResponseWriter, _ *http.Request) {})
}

if authSVC {
r.Use(auth.RequireJWT)
}

r.Route("/monitor", func(r chi.Router) {
r.With(apiAuthMiddleware).Route("/monitor", func(r chi.Router) {
r.Get("/pepr/", monitor.Pepr)
r.Get("/pepr/{stream}", monitor.Pepr)
r.Get("/cluster-overview", monitor.BindClusterOverviewHandler(k8sResources.cache))
})

r.Route("/resources", func(r chi.Router) {
r.With(apiAuthMiddleware).Route("/resources", func(r chi.Router) {
r.Get("/nodes", withLatestCache(k8sResources, getNodes))
r.Get("/nodes/{uid}", withLatestCache(k8sResources, getNode))

Expand Down Expand Up @@ -229,7 +234,7 @@ func Setup(assets *embed.FS) (*chi.Mux, error) {
ip := "127.0.0.1"
colorYellow := "\033[33m"
colorReset := "\033[0m"
url := fmt.Sprintf("http://%s:%s/auth?token=%s", ip, port, token)
url := fmt.Sprintf("http://%s:%s?token=%s", ip, port, token)
log.Printf("%sRuntime API connection: %s%s", colorYellow, url, colorReset)
err := exec.LaunchURL(url)
if err != nil {
Expand Down Expand Up @@ -346,18 +351,6 @@ func checkForClusterAuth() bool {
return authSVC
}

func serveAuthStatus(w http.ResponseWriter, _ *http.Request) {
authStatus := map[string]string{
"API_AUTH_DISABLED": os.Getenv("API_AUTH_DISABLED"),
}

w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(authStatus)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

// withLatestCache returns a wrapper lambda function, creating a closure that can dynamically access the latest cache
func withLatestCache(k8sResources *K8sResources, handler func(cache *resources.Cache) func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand Down
7 changes: 1 addition & 6 deletions ui/src/lib/components/k8s/DataTable/component.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import type { Row as NamespaceRow } from '$features/k8s/namespaces/store'
import { type ResourceStoreInterface } from '$features/k8s/types'
import { addToast } from '$features/toast'
import { apiAuthEnabled } from '$lib/features/api-auth/store'
import { ChevronDown, ChevronUp, Filter, Information, Search } from 'carbon-icons-svelte'
// Determine if the data is namespaced
Expand Down Expand Up @@ -68,11 +67,7 @@
// Fetch the resource data
let results
if (!apiAuthEnabled) {
results = await fetch(`${apiPath}/${uid}`)
} else {
results = await fetch(`${apiPath}/${uid}?token=${sessionStorage.getItem('token')}`)
}
results = await fetch(`${apiPath}/${uid}`)
// If the fetch is successful, set the resource data
if (results.ok) {
Expand Down
1 change: 0 additions & 1 deletion ui/src/lib/features/api-auth/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
import { writable } from 'svelte/store'

export const authenticated = writable(false)
export const apiAuthEnabled = writable<null | boolean>(null)
3 changes: 1 addition & 2 deletions ui/src/lib/features/k8s/cluster-overview/component.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import { onMount } from 'svelte'
import { Card, LinkCard, ProgressBar } from '$components'
import { createEventSource } from '$lib/utils/helpers'
import ApexCharts from 'apexcharts'
import type { ApexOptions } from 'apexcharts'
Expand Down Expand Up @@ -200,7 +199,7 @@
onMount(() => {
const path: string = `/api/v1/monitor/cluster-overview`
const overview = createEventSource(path)
const overview = new EventSource(path)
overview.onmessage = (event) => {
clusterData = JSON.parse(event.data) as ClusterData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import {
type K8StatusMapping,
type ResourceStoreInterface,
} from '$features/k8s/types'
import { createEventSource } from '$lib/utils/helpers'

interface Row extends CommonRow {
storage_class: string
Expand All @@ -31,7 +30,7 @@ export function createStore(): ResourceStoreInterface<Resource, Row> {
const podStore = writable<number>()
const jsonPathFields = 'metadata.name,spec.volumes,status.phase'
const podEventsPath = `/api/v1/resources/workloads/pods?fields=${jsonPathFields}`
const podEvents = createEventSource(podEventsPath)
const podEvents = new EventSource(podEventsPath)

podEvents.onmessage = (event) => {
const data = JSON.parse(event.data) as V1Pod[]
Expand Down
3 changes: 1 addition & 2 deletions ui/src/lib/features/k8s/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import { derived, writable, type Writable } from 'svelte/store'

import type { KubernetesObject } from '@kubernetes/client-node'
import { createEventSource } from '$lib/utils/helpers'
import { differenceInDays, differenceInHours, differenceInMinutes, differenceInSeconds } from 'date-fns'

import { SearchByType, type CommonRow, type ResourceStoreInterface, type ResourceWithTable } from './types'
Expand Down Expand Up @@ -185,7 +184,7 @@ export class ResourceStore<T extends KubernetesObject, U extends CommonRow> impl

this.#initialized = true

this.#eventSource = createEventSource(this.url)
this.#eventSource = new EventSource(this.url)

this.#eventSource.onmessage = ({ data }) => {
try {
Expand Down
Loading

0 comments on commit 45a4e76

Please sign in to comment.