fix: reconnection handling after introducing TLS (#412)
Co-authored-by: UncleGedd <42304551+UncleGedd@users.noreply.github.com>
TristanHoladay and UncleGedd authored Oct 8, 2024
1 parent aaa2fed commit b89cf16
Showing 14 changed files with 273 additions and 288 deletions.
4 changes: 2 additions & 2 deletions pkg/api/handlers.go
@@ -864,6 +864,6 @@ func getStorageClass(cache *resources.Cache) func(w http.ResponseWriter, r *http
 // @Produce json
 // @Success 200
 // @Router /health [get]
-func checkClusteConnection(k8sSession *session.K8sSession, disconnected chan error) http.HandlerFunc {
-    return session.MonitorConnection(k8sSession, disconnected)
+func checkClusteConnection(k8sSession *session.K8sSession) http.HandlerFunc {
+    return k8sSession.ServeConnStatus()
 }
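With this change the /health endpoint streams connection status as server-sent events instead of returning a one-shot JSON payload. A minimal client sketch for consuming that stream, assuming the API is served at http://localhost:8080 and that no auth token is needed (both assumptions; adjust the URL and headers for a real deployment):

package main

import (
    "bufio"
    "fmt"
    "log"
    "net/http"
    "strings"
)

func main() {
    // Open the SSE stream exposed by the /health route.
    resp, err := http.Get("http://localhost:8080/health")
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // Each update arrives as a "data: <status>" line, e.g. "data: success" or "data: error".
    scanner := bufio.NewScanner(resp.Body)
    for scanner.Scan() {
        line := scanner.Text()
        if strings.HasPrefix(line, "data: ") {
            fmt.Println("cluster status:", strings.TrimPrefix(line, "data: "))
        }
    }
}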
2 changes: 1 addition & 1 deletion pkg/api/resources/cache.go
@@ -114,7 +114,7 @@ func NewCache(ctx context.Context, clients *client.Clients) (*Cache, error) {
     go c.factory.Start(c.stopper)
     go c.dynamicFactory.Start(c.stopper)
 
-    // Wait for the pod cache to sync as they it is required for metrics collection
+    // Wait for the pod cache to sync as it is required for metrics collection
     if !cache.WaitForCacheSync(ctx.Done(), c.Pods.HasSynced) {
         return nil, fmt.Errorf("timed out waiting for caches to sync")
     }
10 changes: 3 additions & 7 deletions pkg/api/start.go
@@ -20,7 +20,6 @@ import (
     udsMiddleware "github.com/defenseunicorns/uds-runtime/pkg/api/middleware"
     "github.com/defenseunicorns/uds-runtime/pkg/api/monitor"
     "github.com/defenseunicorns/uds-runtime/pkg/api/resources"
-    "github.com/defenseunicorns/uds-runtime/pkg/k8s/client"
     "github.com/defenseunicorns/uds-runtime/pkg/k8s/session"
     "github.com/go-chi/chi/v5"
     "github.com/go-chi/chi/v5/middleware"
@@ -45,17 +44,14 @@ func Setup(assets *embed.FS) (*chi.Mux, bool, error) {

     inCluster := k8sSession.InCluster
 
-    // Create the disconnected channel
-    disconnected := make(chan error)
-
     if !inCluster {
         apiAuth, token, err = checkForLocalAuth()
         if err != nil {
             return nil, inCluster, fmt.Errorf("failed to set auth: %w", err)
         }
 
-        // Start the reconnection goroutine
-        go k8sSession.HandleReconnection(disconnected, client.NewClient, resources.NewCache)
+        // Start the cluster monitoring goroutine
+        go k8sSession.StartClusterMonitoring()
     }
 
     authSVC := checkForClusterAuth()
@@ -83,7 +79,7 @@ func Setup(assets *embed.FS) (*chi.Mux, bool, error) {
         http.Redirect(w, r, "/swagger/index.html", http.StatusMovedPermanently)
     })
     r.Get("/swagger/*", httpSwagger.WrapHandler)
-    r.Get("/health", checkClusteConnection(k8sSession, disconnected))
+    r.Get("/health", checkClusteConnection(k8sSession))
     r.Route("/api/v1", func(r chi.Router) {
         // Require a valid token for API calls
         if apiAuth {
199 changes: 96 additions & 103 deletions pkg/k8s/session/session.go
@@ -2,7 +2,6 @@ package session

 import (
     "context"
-    "encoding/json"
     "fmt"
     "log"
     "net/http"
@@ -18,10 +17,14 @@ import (
 type K8sSession struct {
     Clients        *client.Clients
     Cache          *resources.Cache
-    Cancel         context.CancelFunc
     CurrentCtx     string
     CurrentCluster string
+    Cancel         context.CancelFunc
+    Status         chan string
     InCluster      bool
+    ready          bool
+    createCache    createCache
+    createClient   createClient
 }
 
 type createClient func() (*client.Clients, error)
@@ -64,54 +67,82 @@ func CreateK8sSession() (*K8sSession, error) {
         CurrentCluster: currentCluster,
         Cancel:         cancel,
         InCluster:      inCluster,
+        Status:         make(chan string),
+        ready:          true,
+        createCache:    resources.NewCache,
+        createClient:   client.NewClient,
     }
 
     return session, nil
 }

-// HandleReconnection is a goroutine that handles reconnection to the k8s API
-// passing createClient and createCache instead of calling clients.NewClient and resources.NewCache for testing purposes
-func (ks *K8sSession) HandleReconnection(disconnected chan error, createClient createClient,
-    createCache createCache) {
-    for err := range disconnected {
-        log.Printf("Disconnected error received: %v\n", err)
-        for {
-            // Cancel the previous context
-            ks.Cancel()
-            time.Sleep(getRetryInterval())
-
-            currentCtx, currentCluster, err := client.GetCurrentContext()
-            if err != nil {
-                log.Printf("Error fetching current context: %v\n", err)
-                continue
-            }
-
-            // If the current context or cluster is different from the original, skip reconnection
-            if currentCtx != ks.CurrentCtx || currentCluster != ks.CurrentCluster {
-                log.Println("Current context has changed. Skipping reconnection.")
-                continue
-            }
-
-            k8sClient, err := createClient()
-            if err != nil {
-                log.Printf("Retrying to create k8s client: %v\n", err)
-                continue
-            }
-
-            // Create a new context and cache
-            ctx, cancel := context.WithCancel(context.Background())
-            cache, err := createCache(ctx, k8sClient)
-            if err != nil {
-                log.Printf("Retrying to create cache: %v\n", err)
-                continue
-            }
-
-            ks.Clients = k8sClient
-            ks.Cache = cache
-            ks.Cancel = cancel
-            log.Println("Successfully reconnected to k8s and recreated cache")
-            break
-        }
-    }
-}
+// StartClusterMonitoring is a goroutine that checks the connection to the cluster
+func (ks *K8sSession) StartClusterMonitoring() {
+    ticker := time.NewTicker(10 * time.Second)
+    defer ticker.Stop()
+
+    for range ticker.C {
+        // Skip if not ready, prevents race conditions using new cache
+        if !ks.ready {
+            continue
+        }
+        // Perform cluster health check
+        _, err := ks.Clients.Clientset.ServerVersion()
+        if err != nil {
+            ks.Status <- "error"
+            ks.HandleReconnection()
+        } else {
+            ks.Status <- "success"
+        }
+    }
+}
+
+// HandleReconnection infinitely retries to re-create the client and cache of the formerly connected cluster
+func (ks *K8sSession) HandleReconnection() {
+    log.Println("Disconnected error received")
+
+    // Set ready to false to block cluster check ticker
+    ks.ready = false
+
+    for {
+        // Cancel the previous context
+        ks.Cancel()
+        time.Sleep(getRetryInterval())
+
+        currentCtx, currentCluster, err := client.GetCurrentContext()
+        if err != nil {
+            log.Printf("Error fetching current context: %v\n", err)
+            continue
+        }
+
+        // If the current context or cluster is different from the original, skip reconnection
+        if currentCtx != ks.CurrentCtx || currentCluster != ks.CurrentCluster {
+            log.Println("Current context has changed. Skipping reconnection.")
+            continue
+        }
+
+        k8sClient, err := ks.createClient()
+        if err != nil {
+            log.Printf("Retrying to create k8s client: %v\n", err)
+            continue
+        }
+
+        // Create a new context and cache
+        ctx, cancel := context.WithCancel(context.Background())
+        cache, err := ks.createCache(ctx, k8sClient)
+        if err != nil {
+            log.Printf("Retrying to create cache: %v\n", err)
+            continue
+        }
+
+        ks.Clients = k8sClient
+        ks.Cache = cache
+        ks.Cancel = cancel
+
+        ks.ready = true
+        log.Println("Successfully reconnected to cluster and recreated cache")
+
+        break
+    }
+}
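HandleReconnection now reads its constructors from the session (ks.createClient, ks.createCache) rather than taking them as parameters, which keeps the call site in start.go simple while still allowing stubs to be injected. A self-contained sketch of that constructor-injection pattern, using simplified stand-in types rather than the project's real client and cache:

package main

import (
    "errors"
    "fmt"
)

// Stand-in types; the real code uses *client.Clients and *resources.Cache.
type fakeClient struct{ healthy bool }
type fakeCache struct{}

type session struct {
    createClient func() (*fakeClient, error)
    createCache  func(c *fakeClient) (*fakeCache, error)
}

// reconnect rebuilds the client and cache through the injected constructors.
func (s *session) reconnect() error {
    c, err := s.createClient()
    if err != nil {
        return fmt.Errorf("client: %w", err)
    }
    if _, err := s.createCache(c); err != nil {
        return fmt.Errorf("cache: %w", err)
    }
    return nil
}

func main() {
    s := &session{
        createClient: func() (*fakeClient, error) { return &fakeClient{healthy: true}, nil },
        createCache: func(c *fakeClient) (*fakeCache, error) {
            if !c.healthy {
                return nil, errors.New("client not healthy")
            }
            return &fakeCache{}, nil
        },
    }
    fmt.Println("reconnect error:", s.reconnect())
}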

@@ -126,83 +157,45 @@ func getRetryInterval() time.Duration {
     return 5 * time.Second // Default to 5 seconds if not set
 }
 
-func MonitorConnection(k8sSession *K8sSession, disconnected chan error) http.HandlerFunc {
+// ServeConnStatus returns a handler function that streams the connection status to the client
+func (ks *K8sSession) ServeConnStatus() http.HandlerFunc {
     return func(w http.ResponseWriter, r *http.Request) {
         // Set headers to keep connection alive
         rest.WriteHeaders(w)
 
-        ticker := time.NewTicker(30 * time.Second)
-        defer ticker.Stop()
-
-        recovering := false
-
-        // Function to check the cluster health when running out of cluster
-        checkCluster := func() {
-            versionInfo, err := k8sSession.Clients.Clientset.ServerVersion()
-            response := map[string]string{}
-
-            // if err then connection is lost
-            if err != nil {
-                response["error"] = err.Error()
-                w.WriteHeader(http.StatusInternalServerError)
-                disconnected <- err
-                // indicate that the reconnection handler should have been triggered by the disconnected channel
-                recovering = true
-            } else if recovering {
-                // if errors are resolved, send a reconnected message
-                response["reconnected"] = versionInfo.String()
-                recovering = false
-            } else {
-                response["success"] = versionInfo.String()
-                w.WriteHeader(http.StatusOK)
-            }
-
-            data, err := json.Marshal(response)
-            if err != nil {
-                http.Error(w, fmt.Sprintf("data: Error: %v\n\n", err), http.StatusInternalServerError)
-                return
-            }
-
-            // Write the data to the response
-            fmt.Fprintf(w, "data: %s\n\n", data)
-
-            // Flush the response to ensure it is sent to the client
-            if flusher, ok := w.(http.Flusher); ok {
-                flusher.Flush()
-            }
-        }
-
-        // If running in cluster don't check for version and send error or reconnected events
-        if k8sSession.InCluster {
-            checkCluster = func() {
-                response := map[string]string{
-                    "success": "in-cluster",
-                }
-                data, err := json.Marshal(response)
-                if err != nil {
-                    http.Error(w, fmt.Sprintf("data: Error: %v\n\n", err), http.StatusInternalServerError)
-                    return
-                }
-                // Write the data to the response
-                fmt.Fprintf(w, "data: %s\n\n", data)
-
-                // Flush the response to ensure it is sent to the client
-                if flusher, ok := w.(http.Flusher); ok {
-                    flusher.Flush()
-                }
-            }
-        }
-
-        // Check the cluster immediately
-        checkCluster()
+        flusher, ok := w.(http.Flusher)
+        if !ok {
+            http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
+            return
+        }
+
+        // If running in cluster don't check connection
+        if ks.InCluster {
+            fmt.Fprint(w, "event: close\ndata: in-cluster\n\n")
+            flusher.Flush()
+        }
+
+        sendStatus := func(msg string) {
+            fmt.Fprintf(w, "data: %s\n\n", msg)
+            flusher.Flush()
+        }
+
+        // To mitigate timing between connection start and getting status updates, immediately check cluster connection
+        _, err := ks.Clients.Clientset.ServerVersion()
+        if err != nil {
+            sendStatus("error")
+        } else {
+            sendStatus("success")
+        }
 
         // Listen for updates and send them to the client
         for {
             select {
-            case <-ticker.C:
-                checkCluster()
+            case msg := <-ks.Status:
+                sendStatus(msg)
+
             case <-r.Context().Done():
-                // Client closed the connection
+                // Client disconnected
                 return
             }
         }
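Taken together, StartClusterMonitoring publishes "success"/"error" on the Status channel and ServeConnStatus forwards each message to the browser as an SSE event. A standalone sketch of that monitor/stream split, with placeholder check logic, a shorter interval, and a single subscriber rather than the project's real cluster check:

package main

import (
    "fmt"
    "log"
    "net/http"
    "time"
)

func main() {
    status := make(chan string)

    // Monitor goroutine: publish a status string on every tick.
    go func() {
        ticker := time.NewTicker(5 * time.Second)
        defer ticker.Stop()
        for range ticker.C {
            // A real check would call the Kubernetes API here.
            status <- "success"
        }
    }()

    // SSE handler: forward whatever the monitor publishes to the connected client.
    http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "text/event-stream")
        w.Header().Set("Cache-Control", "no-cache")

        flusher, ok := w.(http.Flusher)
        if !ok {
            http.Error(w, "streaming unsupported", http.StatusInternalServerError)
            return
        }

        for {
            select {
            case msg := <-status:
                fmt.Fprintf(w, "data: %s\n\n", msg)
                flusher.Flush()
            case <-r.Context().Done():
                return
            }
        }
    })

    log.Fatal(http.ListenAndServe(":8080", nil))
}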
