Skip to content

Commit

Permalink
Implement Keysync Backup/Restore
Browse files Browse the repository at this point in the history
This is a first pass at implementing an encrypted backup feature for keysync.
This writes an AES-GCM encrypted tar file to a configured disk location,
containing the secrets keysync has written out to the regular secrets location.

The Keyrestore program is updated to take these encrypted files and restore
them.

This is useful if you want to avoid writing plaintext secrets to disk, but want
a way to reboot or recover servers without Keywhiz running (e.g., in an outage
scenario).

Right now a fixed AES key is used.  It is loaded after syncing completes, so
you can use a key stored in Keywhiz.  Recovering that key is left as an
exercise to the reader, or future improvements to this tool.
  • Loading branch information
mcpherrinm committed Sep 3, 2019
1 parent 6d9361f commit 1576878
Show file tree
Hide file tree
Showing 23 changed files with 735 additions and 111 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
fixtures/clients/client[1-4]
testing/secrets
keysync
keyrestore
/keysync
/cmd/keysync/keysync
/keyrestore
/cmd/keyrestore/keyrestore
3 changes: 2 additions & 1 deletion Dockerfile-test
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ RUN go mod download
COPY . /opt/keysync

WORKDIR /opt/keysync/cmd/keysync

RUN go build -o /usr/bin/keysync

WORKDIR /opt/keysync/cmd/keyrestore
RUN go build -o /usr/bin/keyrestore

CMD /opt/keysync/testing/run-tests.sh
68 changes: 47 additions & 21 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"net/http/pprof"
"time"

"github.com/square/keysync/backup"

"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
sqmetrics "github.com/square/go-sq-metrics"
Expand All @@ -38,19 +40,21 @@ const (

// APIServer holds state needed for responding to HTTP api requests
type APIServer struct {
backup backup.Backup
syncer *Syncer
logger *logrus.Entry
}

// StatusResponse from API endpoints
type StatusResponse struct {
Ok bool `json:"ok"`
Message string `json:"message"`
Ok bool `json:"ok"`
Message string `json:"message,omitempty"`
Updated *Updated `json:"updated,omitempty"`
}

func writeSuccess(w http.ResponseWriter) {
resp := &StatusResponse{Ok: true}
out, _ := json.MarshalIndent(resp, "", "")
func writeSuccess(w http.ResponseWriter, updated *Updated) {
resp := &StatusResponse{Ok: true, Updated: updated}
out, _ := json.MarshalIndent(resp, "", " ")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(out)
_, _ = w.Write([]byte("\n"))
Expand All @@ -66,15 +70,15 @@ func writeError(w http.ResponseWriter, status int, err error) {

func (a *APIServer) syncAll(w http.ResponseWriter, r *http.Request) {
a.logger.Info("Syncing all from API")
errors := a.syncer.RunOnce()
if len(errors) != 0 {
err := fmt.Errorf("errors: %v", errors)
updated, errs := a.syncer.RunOnce()
if len(errs) != 0 {
err := fmt.Errorf("errors: %v", errs)
a.logger.WithError(err).Warn("error syncing")
writeError(w, http.StatusInternalServerError, err)
return
}

writeSuccess(w)
writeSuccess(w, &updated)
}

func (a *APIServer) syncOne(w http.ResponseWriter, r *http.Request) {
Expand All @@ -100,22 +104,41 @@ func (a *APIServer) syncOne(w http.ResponseWriter, r *http.Request) {
// below cases we end up in.
defer pendingCleanup.cleanup(a.logger)

var updated Updated
if syncerEntry, ok := a.syncer.clients[client]; ok {
if err = syncerEntry.Sync(); err != nil {
updated, err = syncerEntry.Sync()
if err != nil {
logger.WithError(err).Warnf("Error syncing %s", client)
writeError(w, http.StatusInternalServerError, fmt.Errorf("error syncing %s: %s", client, err))
return
}
} else {
if _, pending := pendingCleanup.Outputs[client]; !pending {
// If it's not a current client, or one pending cleanup, return an error
logger.Infof("Unknown client: %s", client)
writeError(w, http.StatusNotFound, fmt.Errorf("unknown client: %s", client))
return
}
} else if _, pending := pendingCleanup.Outputs[client]; !pending {
// If it's not a current client, or one pending cleanup, return an error
logger.Infof("Unknown client: %s", client)
writeError(w, http.StatusNotFound, fmt.Errorf("unknown client: %s", client))
return
}

logger.WithFields(logrus.Fields{
"Added": updated.Added,
"Changed": updated.Changed,
"Deleted": updated.Deleted,
}).Info("API requested sync complete")

writeSuccess(w, &updated)
}

func (a *APIServer) runBackup(w http.ResponseWriter, r *http.Request) {
if a.backup == nil {
writeError(w, http.StatusServiceUnavailable, errors.New("Backups not configured"))
return
}

writeSuccess(w)
if err := a.backup.Backup(); err != nil {
writeError(w, http.StatusInternalServerError, err)
} else {
writeSuccess(w, nil)
}
}

func (a *APIServer) status(w http.ResponseWriter, r *http.Request) {
Expand All @@ -132,7 +155,7 @@ func (a *APIServer) status(w http.ResponseWriter, r *http.Request) {
return
}

writeSuccess(w)
writeSuccess(w, nil)
}

// handle wraps the HandlerFunc with logging, and registers it in the given router.
Expand All @@ -149,9 +172,9 @@ func handle(router *mux.Router, path string, methods []string, fn http.HandlerFu
}

// NewAPIServer is the constructor for an APIServer
func NewAPIServer(syncer *Syncer, port uint16, baseLogger *logrus.Entry, metrics *sqmetrics.SquareMetrics) {
func NewAPIServer(syncer *Syncer, backup backup.Backup, port uint16, baseLogger *logrus.Entry, metrics *sqmetrics.SquareMetrics) {
logger := baseLogger.WithField("logger", "api_server")
apiServer := APIServer{syncer: syncer, logger: logger}
apiServer := APIServer{syncer: syncer, logger: logger, backup: backup}
router := mux.NewRouter()

// Debug endpoints
Expand All @@ -164,6 +187,9 @@ func NewAPIServer(syncer *Syncer, port uint16, baseLogger *logrus.Entry, metrics
handle(router, "/sync", httpPost, apiServer.syncAll, logger)
handle(router, "/sync/{client}", httpPost, apiServer.syncOne, logger)

// Create backup
handle(router, "/backup", httpPost, apiServer.runBackup, logger)

// Status and metrics endpoints
router.HandleFunc("/status", apiServer.status).Methods(httpGet...)
handle(router, "/metrics", httpGet, metrics.ServeHTTP, logger)
Expand Down
86 changes: 82 additions & 4 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"testing"
"time"

"github.com/square/keysync/backup"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -70,7 +72,7 @@ func TestApiSyncAllAndSyncClientSuccess(t *testing.T) {
syncer, err := createNewSyncer("fixtures/configs/test-config.yaml", server)
require.Nil(t, err)

NewAPIServer(syncer, port, logrus.NewEntry(logrus.New()), metricsForTest())
NewAPIServer(syncer, nil, port, logrus.NewEntry(logrus.New()), metricsForTest())
waitForServer(t, port)

// Test SyncAll success
Expand Down Expand Up @@ -137,7 +139,7 @@ func TestApiSyncOneError(t *testing.T) {
_, err = syncer.LoadClients()
assert.NotNil(t, err)

NewAPIServer(syncer, port, logrus.NewEntry(logrus.New()), metricsForTest())
NewAPIServer(syncer, nil, port, logrus.NewEntry(logrus.New()), metricsForTest())
waitForServer(t, port)

// Test error loading clients when syncing single client
Expand All @@ -157,6 +159,82 @@ func TestApiSyncOneError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

// TestNoBackup ensures the /backup endpoint returns 503 Service Unavailable
// when the API server was created without a backup (nil backup argument).
func TestNoBackup(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping API test in short mode.")
	}

	port := randomPort()

	server := createDefaultServer()
	defer server.Close()

	// Load a test config
	syncer, err := createNewSyncer("fixtures/configs/test-config.yaml", server)
	require.Nil(t, err)

	// Passing nil as the backup leaves backups unconfigured.
	NewAPIServer(syncer, nil, port, logrus.NewEntry(logrus.New()), metricsForTest())
	waitForServer(t, port)

	// Call the /backup API, which should return an error
	req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/backup", port), nil)
	require.Nil(t, err)

	res, err := http.DefaultClient.Do(req)
	require.Nil(t, err)
	// Close the body so the underlying connection is released, not leaked.
	defer res.Body.Close()

	require.EqualValues(t, http.StatusServiceUnavailable, res.StatusCode)
}

// stubBackup is a backup.Backup test double. It counts calls to Backup so a
// test can verify that the API server invoked the backup object.
type stubBackup struct {
	backupCalls int
}

// Compile-time check that *stubBackup satisfies backup.Backup.
var _ backup.Backup = (*stubBackup)(nil)

// Backup records the call and always reports success.
func (s *stubBackup) Backup() error {
	s.backupCalls++
	return nil
}

// Restore does nothing and always reports success.
func (s *stubBackup) Restore() error { return nil }

// TestBackup ensures a POST to /backup returns 200 OK and invokes the
// configured backup object's Backup method exactly once.
func TestBackup(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping API test in short mode.")
	}

	port := randomPort()

	server := createDefaultServer()
	defer server.Close()

	// Load a test config
	syncer, err := createNewSyncer("fixtures/configs/test-config.yaml", server)
	require.Nil(t, err)

	stub := stubBackup{}

	NewAPIServer(syncer, &stub, port, logrus.NewEntry(logrus.New()), metricsForTest())
	waitForServer(t, port)

	// Call the /backup API, which should call the Backup() method
	req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/backup", port), nil)
	require.Nil(t, err)

	res, err := http.DefaultClient.Do(req)
	require.Nil(t, err)
	// Close the body so the underlying connection is released, not leaked.
	defer res.Body.Close()

	require.EqualValues(t, http.StatusOK, res.StatusCode)

	// The backup object should have been called exactly once
	require.EqualValues(t, 1, stub.backupCalls)
}

func TestHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("Skipping API test in short mode.")
Expand All @@ -173,7 +251,7 @@ func TestHealthCheck(t *testing.T) {
_, err = syncer.LoadClients()
assert.NotNil(t, err)

NewAPIServer(syncer, port, logrus.NewEntry(logrus.New()), metricsForTest())
NewAPIServer(syncer, nil, port, logrus.NewEntry(logrus.New()), metricsForTest())
waitForServer(t, port)

// 1. Check that health check returns false if we've never had a success
Expand Down Expand Up @@ -211,7 +289,7 @@ func TestMetricsReporting(t *testing.T) {
_, err = syncer.LoadClients()
assert.NotNil(t, err)

NewAPIServer(syncer, port, logrus.NewEntry(logrus.New()), metricsForTest())
NewAPIServer(syncer, nil, port, logrus.NewEntry(logrus.New()), metricsForTest())
waitForServer(t, port)

// Check health under good conditions
Expand Down
91 changes: 91 additions & 0 deletions backup/backup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Package backup handles reading and writing encrypted .tar files from the secrets directory
// (SecretsDirectory) to a backup file (BackupPath) using the key stored in the file at KeyPath.
package backup

import (
"encoding/hex"
"io/ioutil"

"github.com/square/keysync/output"

"github.com/pkg/errors"
)

// Backup is implemented by types that can create a backup and restore from it.
// FileBackup is the implementation provided by this package.
type Backup interface {
// Backup creates a backup, returning a non-nil error on failure.
Backup() error
// Restore restores from a backup, returning a non-nil error on failure.
Restore() error
}

// FileBackup is a Backup that keeps an encrypted tarball of a secrets
// directory in a single file on disk.
type FileBackup struct {
// SecretsDirectory is the directory Backup archives and Restore extracts into.
SecretsDirectory string
// BackupPath is the file Backup writes the encrypted backup to and Restore reads it from.
BackupPath string
// KeyPath is the file containing the hex-encoded encryption key.
KeyPath string
// Chown is passed to extractTar by Restore.
// NOTE(review): presumably controls whether restored files are chowned — confirm against extractTar.
Chown bool
// EnforceFS is passed to extractTar by Restore.
// NOTE(review): presumably the filesystem type restored files must live on — confirm against extractTar.
EnforceFS output.Filesystem
}

// Compile-time check that *FileBackup implements Backup.
var _ Backup = &FileBackup{}

// loadKey reads the hex-encoded key from the file at b.KeyPath and returns
// the decoded key bytes.
//
// Leading and trailing ASCII whitespace in the file is ignored, so a key file
// that ends in a newline still decodes; hex.Decode itself rejects any
// non-hexadecimal byte. Read and decode errors are returned unchanged.
func (b *FileBackup) loadKey() ([]byte, error) {
	keyhex, err := ioutil.ReadFile(b.KeyPath)
	if err != nil {
		return nil, err
	}
	keyhex = trimKeyWhitespace(keyhex)

	key := make([]byte, hex.DecodedLen(len(keyhex)))
	if _, err := hex.Decode(key, keyhex); err != nil {
		return nil, err
	}
	return key, nil
}

// trimKeyWhitespace returns data with leading and trailing spaces, tabs,
// carriage returns and newlines removed. The result aliases data.
func trimKeyWhitespace(data []byte) []byte {
	isSpace := func(c byte) bool {
		return c == ' ' || c == '\t' || c == '\r' || c == '\n'
	}
	for len(data) > 0 && isSpace(data[0]) {
		data = data[1:]
	}
	for len(data) > 0 && isSpace(data[len(data)-1]) {
		data = data[:len(data)-1]
	}
	return data
}

// Backup tars the files in b.SecretsDirectory (via createTar), encrypts the
// tarball with the key loaded from b.KeyPath, and writes the ciphertext to
// b.BackupPath using output.WriteFileAtomically.
//
// The key is loaded only after the tarball has been created. Every error is
// returned wrapped with a message naming the step that failed.
func (b *FileBackup) Backup() error {
	tarball, err := createTar(b.SecretsDirectory)
	if err != nil {
		return errors.Wrap(err, "error creating tarball")
	}

	key, err := b.loadKey()
	if err != nil {
		return errors.Wrap(err, "error loading backup key")
	}

	// Encrypt it
	encrypted, err := encrypt(tarball, key)
	if err != nil {
		return errors.Wrap(err, "error encrypting backup")
	}

	// We always write as r-- --- ---, aka 0400
	// UID/GID in this struct are ignored as chownFiles: false
	perms := output.FileInfo{Mode: 0400}
	if _, err := output.WriteFileAtomically(b.BackupPath, false, perms, 0, encrypted); err != nil {
		return errors.Wrap(err, "error writing backup")
	}
	return nil
}

// Restore reads the encrypted backup at b.BackupPath, decrypts it with the key
// loaded from b.KeyPath, and extracts the resulting tarball into
// b.SecretsDirectory. b.Chown and b.EnforceFS are passed through to extractTar.
//
// Every error is returned wrapped with a message naming the step that failed.
func (b *FileBackup) Restore() error {
	ciphertext, err := ioutil.ReadFile(b.BackupPath)
	if err != nil {
		return errors.Wrap(err, "error reading backup")
	}

	key, err := b.loadKey()
	if err != nil {
		return errors.Wrap(err, "error loading backup key")
	}

	tarball, err := decrypt(ciphertext, key)
	if err != nil {
		return errors.Wrap(err, "error decrypting backup")
	}

	if err := extractTar(tarball, b.Chown, b.SecretsDirectory, b.EnforceFS); err != nil {
		return errors.Wrap(err, "Error extracting tarball")
	}

	return nil
}
Loading

0 comments on commit 1576878

Please sign in to comment.