From b365bda2a59640f79117f38ec84c23997b913b10 Mon Sep 17 00:00:00 2001 From: Tejas Dinkar Date: Thu, 27 Aug 2020 23:13:03 +0530 Subject: [PATCH] fix: Online Restore honors credentials passed in (#6295) (cherry picked from commit a8a6e85b790e3464fefd77ab6db95b23c281149f) --- worker/backup_handler.go | 5 ++--- worker/online_restore_ee.go | 11 ++++++++++- worker/restore.go | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/worker/backup_handler.go b/worker/backup_handler.go index 55c4cde4455..1e5aedcf14e 100644 --- a/worker/backup_handler.go +++ b/worker/backup_handler.go @@ -148,14 +148,13 @@ type loadFn func(reader io.Reader, groupId uint32, preds predicateSet) (uint64, // LoadBackup will scan location l for backup files in the given backup series and load them // sequentially. Returns the maximum Since value on success, otherwise an error. -func LoadBackup(location, backupId string, fn loadFn) LoadResult { +func LoadBackup(location, backupId string, creds *Credentials, fn loadFn) LoadResult { uri, err := url.Parse(location) if err != nil { return LoadResult{0, 0, err} } - // TODO(martinmr): allow overriding credentials during restore. - h := getHandler(uri.Scheme, nil) + h := getHandler(uri.Scheme, creds) if h == nil { return LoadResult{0, 0, errors.Errorf("Unsupported URI: %v", uri)} } diff --git a/worker/online_restore_ee.go b/worker/online_restore_ee.go index c4bef0194cf..7616431b5aa 100644 --- a/worker/online_restore_ee.go +++ b/worker/online_restore_ee.go @@ -285,8 +285,17 @@ func getEncConfig(req *pb.RestoreRequest) (*viper.Viper, error) { return config, nil } +func getCredentialsFromRestoreRequest(req *pb.RestoreRequest) *Credentials { + return &Credentials{ + AccessKey: req.AccessKey, + SecretKey: req.SecretKey, + SessionToken: req.SessionToken, + Anonymous: req.Anonymous, + } +} + func writeBackup(ctx context.Context, req *pb.RestoreRequest) error { - res := LoadBackup(req.Location, req.BackupId, + res := LoadBackup(req.Location, req.BackupId, getCredentialsFromRestoreRequest(req), func(r io.Reader, groupId uint32, preds predicateSet) (uint64, error) { if groupId != req.GroupId { // LoadBackup will try to call the backup function for every group. diff --git a/worker/restore.go b/worker/restore.go index fe0354e0cb7..6bbb7129c8a 100644 --- a/worker/restore.go +++ b/worker/restore.go @@ -43,7 +43,7 @@ func RunRestore(pdir, location, backupId string, key x.SensitiveByteSlice) LoadR // Scan location for backup files and load them. Each file represents a node group, // and we create a new p dir for each. - return LoadBackup(location, backupId, + return LoadBackup(location, backupId, nil, func(r io.Reader, groupId uint32, preds predicateSet) (uint64, error) { dir := filepath.Join(pdir, fmt.Sprintf("p%d", groupId))