From 20c682db9fea55446406b4bd0dc4457f5585f81b Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 07:43:04 +0800 Subject: [PATCH 01/37] Encrypt region boundary keys, Part 2 - server changes Signed-off-by: Yi Wu --- conf/config.toml | 47 ++++++++++ go.mod | 1 + pkg/encryption/config.go | 141 ++++++++++++++++++++++++++++ pkg/encryption/kms.go | 19 ++++ pkg/mock/mockcluster/mockcluster.go | 2 +- server/cluster/coordinator.go | 7 +- server/config/config.go | 11 ++- server/core/region_storage.go | 94 +++++++++++++++---- server/core/storage.go | 59 +++++++----- server/encryption/key_manager.go | 54 +++++++++++ server/server.go | 18 +++- 11 files changed, 405 insertions(+), 48 deletions(-) create mode 100644 pkg/encryption/config.go create mode 100644 pkg/encryption/kms.go create mode 100644 server/encryption/key_manager.go diff --git a/conf/config.toml b/conf/config.toml index 11c45d2ddea..d00fbf3ba2f 100644 --- a/conf/config.toml +++ b/conf/config.toml @@ -32,6 +32,53 @@ key-path = "" cert-allowed-cn = ["example.com"] +[security.encryption] +## Encryption method to use for PD data. One of "plaintext", "aes128-ctr", "aes192-ctr" and "aes256-ctr". +## Defaults to "plaintext" if not set. +# data-encryption-method = "plaintext" +## Specifies how often PD rotates data encryption key. +# data-key-rotation-period = "7d" + +## Specifies master key if encryption is enabled. There are three types of master key: +## +## * "plaintext": +## +## Plaintext as master key means no master key is given and only applicable when +## encryption is not enabled, i.e. data-encryption-method = "plaintext". This type doesn't +## have sub-config items. Example: +## +## [security.encryption.master-key] +## type = "plaintext" +## +## * "kms": +## +## Use a KMS service to supply master key. Currently only AWS KMS is supported. This type of +## master key is recommended for production use. Example: +## +## [security.encryption.master-key] +## type = "kms" +## ## KMS CMK key id. 
Must be a valid KMS CMK where the TiKV process has access to. +## ## In production is recommended to grant access of the CMK to TiKV using IAM. +## key-id = "1234abcd-12ab-34cd-56ef-1234567890ab" +## ## AWS region of the KMS CMK. +## region = "us-west-2" +## ## (Optional) AWS KMS service endpoint. Only required when non-default KMS endpoint is +## ## desired. +## endpoint = "https://kms.us-west-2.amazonaws.com" +## +## * "file": +## +## Supply a custom encryption key stored in a file. It is recommended NOT to use in production, +## as it breaks the purpose of encryption at rest, unless the file is stored in tempfs. +## The file must contain a 256-bits (32 bytes, regardless of key length implied by +## data-encryption-method) key encoded as hex string and end with newline ("\n"). Example: +## +## [security.encryption.master-key] +## type = "file" +## path = "/path/to/master/key/file" +# [security.encryption.master-key] +# type = "plaintext" + [log] level = "info" diff --git a/go.mod b/go.mod index da43ee53bec..c2bbd32bc19 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20200916031750-f9473f2c5379 github.com/pingcap/log v0.0.0-20200511115504-543df19646ad github.com/pingcap/sysutil v0.0.0-20200715082929-4c47bcac246a + github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.0.0 github.com/prometheus/common v0.4.1 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go new file mode 100644 index 00000000000..6c048b18e96 --- /dev/null +++ b/pkg/encryption/config.go @@ -0,0 +1,141 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encryption + +import ( + "time" + + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pkg/errors" + "github.com/tikv/pd/pkg/typeutil" +) + +const ( + methodPlaintext = "plaintext" + methodAes128Ctr = "aes128-ctr" + methodAes192Ctr = "aes192-ctr" + methodAes256Ctr = "aes256-ctr" + + masterKeyTypePlaintext = "plaintext" + masterKeyTypeKMS = "kms" + masterKeyTypeFile = "file" + + defaultDataEncryptionMethod = methodPlaintext + defaultDataKeyRotationPeriod = "168h" // 7 days +) + +type Config struct { + // Encryption method to use for PD data. + DataEncryptionMethod string `toml:"data-encryption-method" json:"data-encryption-method"` + // Specifies how often PD rotates data encryption key. + DataKeyRotationPeriod typeutil.Duration `toml:"data-key-rotation-period" json:"data-key-rotation-period"` + // Specifies master key if encryption is enabled. 
+ MasterKey MasterKeyConfig `toml:"master-key" json:"master-key"` +} + +func (c *Config) Adjust() error { + if len(c.DataEncryptionMethod) == 0 { + c.DataEncryptionMethod = methodPlaintext + } else { + if _, err := c.GetMethod(); err != nil { + return err + } + } + if c.DataKeyRotationPeriod.Duration == 0 { + duration, err := time.ParseDuration(defaultDataKeyRotationPeriod) + if err != nil { + return errors.Wrapf(err, "fail to parse default value of data-key-rotation-period %s", + defaultDataKeyRotationPeriod) + } + c.DataKeyRotationPeriod.Duration = duration + } + if len(c.MasterKey.Type) == 0 { + c.MasterKey.Type = masterKeyTypePlaintext + } else { + if _, err := c.GetMasterKey(); err != nil { + return err + } + } + return nil +} + +func (c *Config) GetMethod() (encryptionpb.EncryptionMethod, error) { + switch c.DataEncryptionMethod { + case methodPlaintext: + return encryptionpb.EncryptionMethod_PLAINTEXT, nil + case methodAes128Ctr: + return encryptionpb.EncryptionMethod_AES128_CTR, nil + case methodAes192Ctr: + return encryptionpb.EncryptionMethod_AES192_CTR, nil + case methodAes256Ctr: + return encryptionpb.EncryptionMethod_AES256_CTR, nil + default: + return encryptionpb.EncryptionMethod_UNKNOWN, + errors.Errorf("invalid encryption method %s", c.DataEncryptionMethod) + } +} + +func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { + switch c.MasterKey.Type { + case masterKeyTypePlaintext: + return &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Plaintext{ + Plaintext: &encryptionpb.MasterKeyPlaintext{}, + }, + }, nil + case masterKeyTypeKMS: + return &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: kmsVendorAWS, + KeyId: c.MasterKey.KmsKeyId, + Region: c.MasterKey.KmsRegion, + Endpoint: c.MasterKey.KmsEndpoint, + }, + }, + }, nil + case masterKeyTypeFile: + return &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: 
c.MasterKey.FilePath, + }, + }, + }, nil + default: + return nil, errors.Errorf("unrecognized encryption master key type: %s", c.MasterKey.Type) + } +} + +type MasterKeyConfig struct { + // Master key type, one of "plaintext", "kms" or "file". + Type string `toml:"type" json:"type"` + + MasterKeyKMSConfig + MasterKeyFileConfig +} + +type MasterKeyKMSConfig struct { + // KMS CMK key id. + KmsKeyId string `toml:"key-id" json:"key-id"` + // KMS region of the CMK. + KmsRegion string `toml:"region" json:"region"` + // Custom endpoint to access KMS. + KmsEndpoint string `toml:"endpoint" json:"endpoint"` +} + +type MasterKeyFileConfig struct { + // Master key file path. + FilePath string `toml:"path" json:"path"` +} diff --git a/pkg/encryption/kms.go b/pkg/encryption/kms.go new file mode 100644 index 00000000000..eb613cceb14 --- /dev/null +++ b/pkg/encryption/kms.go @@ -0,0 +1,19 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encryption + +const ( + // We only support AWS KMS right now. 
+ kmsVendorAWS = "AWS" +) diff --git a/pkg/mock/mockcluster/mockcluster.go b/pkg/mock/mockcluster/mockcluster.go index deb019fdca1..c7b8f9450ea 100644 --- a/pkg/mock/mockcluster/mockcluster.go +++ b/pkg/mock/mockcluster/mockcluster.go @@ -135,7 +135,7 @@ func (mc *Cluster) AllocPeer(storeID uint64) (*metapb.Peer, error) { func (mc *Cluster) initRuleManager() { if mc.RuleManager == nil { - mc.RuleManager = placement.NewRuleManager(core.NewStorage(kv.NewMemoryKV())) + mc.RuleManager = placement.NewRuleManager(core.NewStorage(kv.NewMemoryKV(), nil, nil)) mc.RuleManager.Initialize(int(mc.GetReplicationConfig().MaxReplicas), mc.GetReplicationConfig().LocationLabels) } } diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go index b2335565bf9..4f2422f05d4 100644 --- a/server/cluster/coordinator.go +++ b/server/cluster/coordinator.go @@ -590,7 +590,12 @@ func (c *coordinator) removeOptScheduler(o *config.PersistOptions, name string) for i, schedulerCfg := range v.Schedulers { // To create a temporary scheduler is just used to get scheduler's name decoder := schedule.ConfigSliceDecoder(schedulerCfg.Type, schedulerCfg.Args) - tmp, err := schedule.CreateScheduler(schedulerCfg.Type, schedule.NewOperatorController(c.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), decoder) + tmp, err := schedule.CreateScheduler( + schedulerCfg.Type, + schedule.NewOperatorController(c.ctx, nil, nil), + core.NewStorage(kv.NewMemoryKV(), nil, nil), + decoder, + ) if err != nil { return err } diff --git a/server/config/config.go b/server/config/config.go index fd8c81e3aa4..f5670f77111 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/metricutil" @@ -121,7 +122,7 @@ type Config struct { // an election, thus minimizing disruptions. 
PreVote bool `toml:"enable-prevote"` - Security grpcutil.SecurityConfig `toml:"security" json:"security"` + Security SecurityConfig `toml:"security" json:"security"` LabelProperty LabelPropertyConfig `toml:"label-property" json:"label-property"` @@ -543,6 +544,8 @@ func (c *Config) Adjust(meta *toml.MetaData) error { c.ReplicationMode.adjust(configMetaData.Child("replication-mode")) + c.Security.Encryption.Adjust() + return nil } @@ -1364,3 +1367,9 @@ func (c *LocalTSOConfig) Validate() error { } return nil } + +// SecurityConfig is the configuration for TLS and encryption. +type SecurityConfig struct { + grpcutil.SecurityConfig + Encryption encryption.Config `toml:"encryption" json:"encryption"` +} diff --git a/server/core/region_storage.go b/server/core/region_storage.go index be6de0d037f..94b99debbff 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -19,9 +19,12 @@ import ( "sync" "time" + "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" + crypter "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/server/encryption" "github.com/tikv/pd/server/kv" ) @@ -30,14 +33,15 @@ var dirtyFlushTick = time.Second // RegionStorage is used to save regions. type RegionStorage struct { *kv.LeveldbKV - mu sync.RWMutex - batchRegions map[string]*metapb.Region - batchSize int - cacheSize int - flushRate time.Duration - flushTime time.Time - regionStorageCtx context.Context - regionStorageCancel context.CancelFunc + encryptionKeyManager *encryption.KeyManager + mu sync.RWMutex + batchRegions map[string]*metapb.Region + batchSize int + cacheSize int + flushRate time.Duration + flushTime time.Time + regionStorageCtx context.Context + regionStorageCancel context.CancelFunc } const ( @@ -47,21 +51,26 @@ const ( defaultBatchSize = 100 ) -// NewRegionStorage returns a region storage that is used to save regions. 
-func NewRegionStorage(ctx context.Context, path string) (*RegionStorage, error) { +// newRegionStorage returns a region storage that is used to save regions. +func NewRegionStorage( + ctx context.Context, + path string, + encryptionKeyManager *encryption.KeyManager, +) (*RegionStorage, error) { levelDB, err := kv.NewLeveldbKV(path) if err != nil { return nil, err } regionStorageCtx, regionStorageCancel := context.WithCancel(ctx) s := &RegionStorage{ - LeveldbKV: levelDB, - batchSize: defaultBatchSize, - flushRate: defaultFlushRegionRate, - batchRegions: make(map[string]*metapb.Region, defaultBatchSize), - flushTime: time.Now().Add(defaultFlushRegionRate), - regionStorageCtx: regionStorageCtx, - regionStorageCancel: regionStorageCancel, + LeveldbKV: levelDB, + encryptionKeyManager: encryptionKeyManager, + batchSize: defaultBatchSize, + flushRate: defaultFlushRegionRate, + batchRegions: make(map[string]*metapb.Region, defaultBatchSize), + flushTime: time.Now().Add(defaultFlushRegionRate), + regionStorageCtx: regionStorageCtx, + regionStorageCancel: regionStorageCancel, } s.backgroundFlush() return s, nil @@ -96,6 +105,10 @@ func (s *RegionStorage) backgroundFlush() { // SaveRegion saves one region to storage. 
func (s *RegionStorage) SaveRegion(region *metapb.Region) error { + err := crypter.EncryptRegion(region, s.encryptionKeyManager) + if err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() if s.cacheSize < s.batchSize-1 { @@ -106,7 +119,7 @@ func (s *RegionStorage) SaveRegion(region *metapb.Region) error { return nil } s.batchRegions[regionPath(region.GetId())] = region - err := s.flush() + err = s.flush() if err != nil { return err @@ -118,7 +131,47 @@ func deleteRegion(kv kv.Base, region *metapb.Region) error { return kv.Remove(regionPath(region.GetId())) } -func loadRegions(kv kv.Base, f func(region *RegionInfo) []*RegionInfo) error { +func saveRegion( + kv kv.Base, + encryptionKeyManager *encryption.KeyManager, + region *metapb.Region, +) error { + err := crypter.EncryptRegion(region, encryptionKeyManager) + if err != nil { + return err + } + value, err := proto.Marshal(region) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStackByArgs() + } + return kv.Save(regionPath(region.GetId()), string(value)) +} + +func loadRegion( + kv kv.Base, + encryptionKeyManager *encryption.KeyManager, + region *metapb.Region, +) (ok bool, err error) { + value, err := kv.Load(regionPath(region.GetId())) + if err != nil { + return false, err + } + if value == "" { + return false, nil + } + err = proto.Unmarshal([]byte(value), region) + if err != nil { + return true, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() + } + err = crypter.DecryptRegion(region, encryptionKeyManager) + return true, err +} + +func loadRegions( + kv kv.Base, + encryptionKeyManager *encryption.KeyManager, + f func(region *RegionInfo) []*RegionInfo, +) error { nextID := uint64(0) endKey := regionPath(math.MaxUint64) @@ -141,6 +194,9 @@ func loadRegions(kv kv.Base, f func(region *RegionInfo) []*RegionInfo) error { if err := region.Unmarshal([]byte(s)); err != nil { return errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() } + if err = crypter.DecryptRegion(region, 
encryptionKeyManager); err != nil { + return err + } nextID = region.GetId() + 1 overlaps := f(NewRegionInfo(region, nil)) diff --git a/server/core/storage.go b/server/core/storage.go index 70702297de4..fd6cc489c5c 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/server/encryption" "github.com/tikv/pd/server/kv" "go.etcd.io/etcd/clientv3" ) @@ -42,6 +43,7 @@ const ( replicationPath = "replication_mode" componentPath = "component" customScheduleConfigPath = "scheduler_config" + encryptionKeysPath = "encryption_keys" ) const ( @@ -52,25 +54,26 @@ const ( // Storage wraps all kv operations, keep it stateless. type Storage struct { kv.Base - regionStorage *RegionStorage - useRegionStorage int32 - regionLoaded int32 - mu sync.Mutex + regionStorage *RegionStorage + encryptionKeyManager *encryption.KeyManager + useRegionStorage int32 + regionLoaded int32 + mu sync.Mutex } // NewStorage creates Storage instance with Base. -func NewStorage(base kv.Base) *Storage { +func NewStorage( + base kv.Base, + regionStorage *RegionStorage, + encryptionKeyManager *encryption.KeyManager, +) *Storage { return &Storage{ - Base: base, + Base: base, + regionStorage: regionStorage, + encryptionKeyManager: encryptionKeyManager, } } -// SetRegionStorage sets the region storage. -func (s *Storage) SetRegionStorage(regionStorage *RegionStorage) *Storage { - s.regionStorage = regionStorage - return s -} - // GetRegionStorage gets the region storage. 
func (s *Storage) GetRegionStorage() *RegionStorage { return s.regionStorage @@ -107,6 +110,10 @@ func (s *Storage) storeRegionWeightPath(storeID uint64) string { return path.Join(schedulePath, "store_weight", fmt.Sprintf("%020d", storeID), "region") } +func (s *Storage) EncryptionKeysPath() string { + return path.Join(encryptionKeysPath, "keys") +} + // SaveScheduleConfig saves the config of scheduler. func (s *Storage) SaveScheduleConfig(scheduleName string, data []byte) error { configPath := path.Join(customScheduleConfigPath, scheduleName) @@ -150,31 +157,31 @@ func (s *Storage) DeleteStore(store *metapb.Store) error { return s.Remove(s.storePath(store.GetId())) } -// LoadRegion loads one region from storage. -func (s *Storage) LoadRegion(regionID uint64, region *metapb.Region) (bool, error) { +// LoadRegion loads one regoin from storage. +func (s *Storage) LoadRegion(regionID uint64, region *metapb.Region) (ok bool, err error) { if atomic.LoadInt32(&s.useRegionStorage) > 0 { - return loadProto(s.regionStorage, regionPath(regionID), region) + return loadRegion(s.regionStorage, s.encryptionKeyManager, region) } - return loadProto(s.Base, regionPath(regionID), region) + return loadRegion(s.Base, s.encryptionKeyManager, region) } // LoadRegions loads all regions from storage to RegionsInfo. func (s *Storage) LoadRegions(f func(region *RegionInfo) []*RegionInfo) error { if atomic.LoadInt32(&s.useRegionStorage) > 0 { - return loadRegions(s.regionStorage, f) + return loadRegions(s.regionStorage, s.encryptionKeyManager, f) } - return loadRegions(s.Base, f) + return loadRegions(s.Base, s.encryptionKeyManager, f) } // LoadRegionsOnce loads all regions from storage to RegionsInfo.Only load one time from regionStorage. 
func (s *Storage) LoadRegionsOnce(f func(region *RegionInfo) []*RegionInfo) error { if atomic.LoadInt32(&s.useRegionStorage) == 0 { - return loadRegions(s.Base, f) + return loadRegions(s.Base, s.encryptionKeyManager, f) } s.mu.Lock() defer s.mu.Unlock() if s.regionLoaded == 0 { - if err := loadRegions(s.regionStorage, f); err != nil { + if err := loadRegions(s.regionStorage, s.encryptionKeyManager, f); err != nil { return err } s.regionLoaded = 1 @@ -187,7 +194,7 @@ func (s *Storage) SaveRegion(region *metapb.Region) error { if atomic.LoadInt32(&s.useRegionStorage) > 0 { return s.regionStorage.SaveRegion(region) } - return saveProto(s.Base, regionPath(region.GetId()), region) + return saveRegion(s.Base, s.encryptionKeyManager, region) } // DeleteRegion deletes one region from storage. @@ -257,7 +264,7 @@ func (s *Storage) LoadRuleGroups(f func(k, v string)) error { func (s *Storage) SaveJSON(prefix, key string, data interface{}) error { value, err := json.Marshal(data) if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByArgs() + return err } return s.Save(path.Join(prefix, key), string(value)) } @@ -401,7 +408,13 @@ func (s *Storage) Flush() error { // Close closes the s. func (s *Storage) Close() error { if s.regionStorage != nil { - return s.regionStorage.Close() + err := s.regionStorage.Close() + if err != nil { + return err + } + } + if s.encryptionKeyManager != nil { + s.encryptionKeyManager.Close() } return nil } diff --git a/server/encryption/key_manager.go b/server/encryption/key_manager.go new file mode 100644 index 00000000000..16f2a9c73bf --- /dev/null +++ b/server/encryption/key_manager.go @@ -0,0 +1,54 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encryption + +import ( + "github.com/pingcap/kvproto/pkg/encryptionpb" + lib "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/server/election" + "github.com/tikv/pd/server/kv" +) + +// KeyManager maintains the list to encryption keys. It handles encryption key generation and +// rotation, persisting and loading encryption keys. +type KeyManager struct{} + +// NewKeyManager creates a new key manager. +func NewKeyManager(kv kv.Base, config *lib.Config) (*KeyManager, error) { + // TODO: Implement + return &KeyManager{}, nil +} + +// GetCurrentKey get the current encryption key. The key is nil if encryption is not enabled. +func (m *KeyManager) GetCurrentKey() (keyID uint64, key *encryptionpb.DataKey, err error) { + // TODO: Implement + return 0, nil, nil +} + +// GetKey get the encryption key with the specific key id. +func (m *KeyManager) GetKey(keyID uint64) (key *encryptionpb.DataKey, err error) { + // TODO: Implement + return nil, nil +} + +// SetLeadership sets the PD leadership of the current node. PD leader is responsible to update +// encryption keys, e.g. key rotation. 
+func (m *KeyManager) SetLeadership(leadership *election.Leadership) { + // TODO: Implement +} + +// Close close the key manager on PD server shutdown +func (m *KeyManager) Close() { + // TODO: Implement +} diff --git a/server/server.go b/server/server.go index 3f9496873fc..871ae4c33b3 100644 --- a/server/server.go +++ b/server/server.go @@ -47,6 +47,7 @@ import ( "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/encryption" "github.com/tikv/pd/server/id" "github.com/tikv/pd/server/kv" "github.com/tikv/pd/server/member" @@ -115,6 +116,8 @@ type Server struct { // store, region and peer, because we just need // a unique ID. idAllocator *id.AllocatorImpl + // for encryption + encryptionKeyManager *encryption.KeyManager // for storage operation. storage *core.Storage // for basicCluster operation. @@ -357,12 +360,18 @@ func (s *Server) startServer(ctx context.Context) error { return err } kvBase := kv.NewEtcdKVBase(s.client, s.rootPath) + encryptionKeyManager, err := encryption.NewKeyManager(kvBase, &s.cfg.Security.Encryption) + if err != nil { + return err + } + s.encryptionKeyManager = encryptionKeyManager path := filepath.Join(s.cfg.DataDir, "region-meta") - regionStorage, err := core.NewRegionStorage(ctx, path) + regionStorage, err := core.NewRegionStorage(ctx, path, encryptionKeyManager) if err != nil { return err } - s.storage = core.NewStorage(kvBase).SetRegionStorage(regionStorage) + + s.storage = core.NewStorage(kvBase, regionStorage, encryptionKeyManager) s.basicCluster = core.NewBasicCluster() s.cluster = cluster.NewRaftCluster(ctx, s.GetClusterRootPath(), s.clusterID, syncer.NewRegionSyncer(s), s.client, s.httpClient) s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, s.cluster) @@ -947,7 +956,7 @@ func (s *Server) GetClusterVersion() semver.Version { // GetSecurityConfig get the security config. 
func (s *Server) GetSecurityConfig() *grpcutil.SecurityConfig { - return &s.cfg.Security + return &s.cfg.Security.SecurityConfig } // GetServerRootPath returns the server root path. @@ -1146,6 +1155,9 @@ func (s *Server) campaignLeader() { log.Error("failed to reload configuration", errs.ZapError(err)) return } + + s.encryptionKeyManager.SetLeadership(s.member.GetLeadership()) + // Try to create raft cluster. if err := s.createRaftCluster(); err != nil { log.Error("failed to create raft cluster", errs.ZapError(err)) From b6f7ba0b1ba0f2b13b14296d18fbb216d7d13ac0 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 08:04:35 +0800 Subject: [PATCH 02/37] update errno and config template Signed-off-by: Yi Wu --- conf/config.toml | 4 ++-- go.mod | 1 - pkg/encryption/config.go | 10 ++++++---- pkg/errs/errno.go | 1 + 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/conf/config.toml b/conf/config.toml index d00fbf3ba2f..88d32cb1f80 100644 --- a/conf/config.toml +++ b/conf/config.toml @@ -36,8 +36,8 @@ cert-allowed-cn = ["example.com"] ## Encryption method to use for PD data. One of "plaintext", "aes128-ctr", "aes192-ctr" and "aes256-ctr". ## Defaults to "plaintext" if not set. # data-encryption-method = "plaintext" -## Specifies how often PD rotates data encryption key. -# data-key-rotation-period = "7d" +## Specifies how often PD rotates data encryption key. Default is 7 days. +# data-key-rotation-period = "168h" ## Specifies master key if encryption is enabled. 
There are three types of master key: ## diff --git a/go.mod b/go.mod index c2bbd32bc19..da43ee53bec 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,6 @@ require ( github.com/pingcap/kvproto v0.0.0-20200916031750-f9473f2c5379 github.com/pingcap/log v0.0.0-20200511115504-543df19646ad github.com/pingcap/sysutil v0.0.0-20200715082929-4c47bcac246a - github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.0.0 github.com/prometheus/common v0.4.1 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go index 6c048b18e96..935cb4de428 100644 --- a/pkg/encryption/config.go +++ b/pkg/encryption/config.go @@ -17,7 +17,7 @@ import ( "time" "github.com/pingcap/kvproto/pkg/encryptionpb" - "github.com/pkg/errors" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/typeutil" ) @@ -55,7 +55,8 @@ func (c *Config) Adjust() error { if c.DataKeyRotationPeriod.Duration == 0 { duration, err := time.ParseDuration(defaultDataKeyRotationPeriod) if err != nil { - return errors.Wrapf(err, "fail to parse default value of data-key-rotation-period %s", + return errs.ErrEncryptionInvalidConfig.Wrap(err).GenWithStack( + "fail to parse default value of data-key-rotation-period %s", defaultDataKeyRotationPeriod) } c.DataKeyRotationPeriod.Duration = duration @@ -82,7 +83,7 @@ func (c *Config) GetMethod() (encryptionpb.EncryptionMethod, error) { return encryptionpb.EncryptionMethod_AES256_CTR, nil default: return encryptionpb.EncryptionMethod_UNKNOWN, - errors.Errorf("invalid encryption method %s", c.DataEncryptionMethod) + errs.ErrEncryptionInvalidMethod.GenWithStack("unknown method") } } @@ -114,7 +115,8 @@ func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { }, }, nil default: - return nil, errors.Errorf("unrecognized encryption master key type: %s", c.MasterKey.Type) + return nil, errs.ErrEncryptionInvalidConfig.GenWithStack( + "unrecognized encryption master key type: %s", c.MasterKey.Type) } } diff --git a/pkg/errs/errno.go 
b/pkg/errs/errno.go index af438f8bd84..8e5ea54c3cf 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -267,6 +267,7 @@ var ( // encryption var ( ErrEncryptionInvalidMethod = errors.Normalize("invalid encryption method", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidMethod")) + ErrEncryptionInvalidConfig = errors.Normalize("invalid config", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidConfig")) ErrEncryptionGenerateIV = errors.Normalize("fail to generate iv", errors.RFCCodeText("PD:encryption:ErrEncryptionGenerateIV")) ErrEncryptionNewDataKey = errors.Normalize("fail to generate data key", errors.RFCCodeText("PD:encryption:ErrEncryptionNewDataKey")) ErrEncryptionGCMEncrypt = errors.Normalize("GCM encryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionGCMEncrypt")) From 0bf03047700710f9adc76525daa5e4344bdaf043 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 08:07:15 +0800 Subject: [PATCH 03/37] fix typo Signed-off-by: Yi Wu --- server/core/storage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/core/storage.go b/server/core/storage.go index fd6cc489c5c..e43fa852c23 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -157,7 +157,7 @@ func (s *Storage) DeleteStore(store *metapb.Store) error { return s.Remove(s.storePath(store.GetId())) } -// LoadRegion loads one regoin from storage. +// LoadRegion loads one region from storage. 
func (s *Storage) LoadRegion(regionID uint64, region *metapb.Region) (ok bool, err error) { if atomic.LoadInt32(&s.useRegionStorage) > 0 { return loadRegion(s.regionStorage, s.encryptionKeyManager, region) From 23954dba4d47f5c4a0da0682432db570f538e62e Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 11:22:37 +0800 Subject: [PATCH 04/37] fix tests Signed-off-by: Yi Wu --- pkg/component/manager_test.go | 2 +- server/cluster/cluster_test.go | 30 +++++++++----- server/core/storage_test.go | 18 ++++----- server/replication/replication_mode_test.go | 10 ++--- .../schedule/placement/rule_manager_test.go | 2 +- server/schedulers/balance_test.go | 40 +++++++++---------- 6 files changed, 56 insertions(+), 46 deletions(-) diff --git a/pkg/component/manager_test.go b/pkg/component/manager_test.go index 3c02a162729..fe360294c1c 100644 --- a/pkg/component/manager_test.go +++ b/pkg/component/manager_test.go @@ -31,7 +31,7 @@ var _ = Suite(&testManagerSuite{}) type testManagerSuite struct{} func (s *testManagerSuite) TestManager(c *C) { - m := NewManager(core.NewStorage(kv.NewMemoryKV())) + m := NewManager(core.NewStorage(kv.NewMemoryKV(), nil, nil)) // register legal address c.Assert(m.Register("c1", "127.0.0.1:1"), IsNil) c.Assert(m.Register("c1", "127.0.0.1:2"), IsNil) diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 7c280226a95..55e7aa76af8 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -46,7 +46,8 @@ type testClusterInfoSuite struct{} func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) n, np := uint64(3), uint64(3) stores := newTestStores(n) @@ -95,7 +96,8 @@ func (s 
*testClusterInfoSuite) TestStoreHeartbeat(c *C) { func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) stores := newTestStores(3) for _, store := range stores { @@ -127,7 +129,8 @@ func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) n, np := uint64(3), uint64(3) @@ -352,7 +355,8 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} processRegions := func(regions []*core.RegionInfo) { for _, r := range regions { @@ -380,7 +384,8 @@ func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { func (s *testClusterInfoSuite) TestConcurrentRegionHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, 
core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} regions = core.SplitRegions(regions) @@ -441,7 +446,8 @@ func heartbeatRegions(c *C, cluster *RaftCluster, regions []*core.RegionInfo) { func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) // 1: [nil, nil) region1 := core.NewRegionInfo(&metapb.Region{Id: 1, RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) @@ -480,7 +486,8 @@ func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { func (s *testClusterInfoSuite) TestRegionSplitAndMerge(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} @@ -587,7 +594,8 @@ func (s *testRegionsInfoSuite) Test(c *C) { regions := newTestRegions(n, np) _, opts, err := newTestScheduleConfig() c.Assert(err, IsNil) - tc := newTestRaftCluster(mockid.NewIDAllocator(), opts, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + tc := newTestRaftCluster( + mockid.NewIDAllocator(), opts, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) cache := tc.core.Regions for i := uint64(0); i < n; i++ { @@ -697,7 +705,8 @@ type testGetStoresSuite struct { func (s *testGetStoresSuite) SetUpSuite(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), 
opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) s.cluster = cluster stores := newTestStores(200) @@ -730,7 +739,8 @@ func newTestScheduleConfig() (*config.ScheduleConfig, *config.PersistOptions, er } func newTestCluster(opt *config.PersistOptions) *testCluster { - rc := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + rc := newTestRaftCluster( + mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) return &testCluster{RaftCluster: rc} } diff --git a/server/core/storage_test.go b/server/core/storage_test.go index 8ccbbffda6f..2f2a7e05762 100644 --- a/server/core/storage_test.go +++ b/server/core/storage_test.go @@ -34,7 +34,7 @@ type testKVSuite struct { } func (s *testKVSuite) TestBasic(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) c.Assert(storage.storePath(123), Equals, "raft/s/00000000000000000123") c.Assert(regionPath(123), Equals, "raft/r/00000000000000000123") @@ -93,7 +93,7 @@ func mustSaveStores(c *C, s *Storage, n int) []*metapb.Store { } func (s *testKVSuite) TestLoadStores(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) cache := NewStoresInfo() n := 10 @@ -107,7 +107,7 @@ func (s *testKVSuite) TestLoadStores(c *C) { } func (s *testKVSuite) TestStoreWeight(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) cache := NewStoresInfo() const n = 3 @@ -138,7 +138,7 @@ func mustSaveRegions(c *C, s *Storage, n int) []*metapb.Region { } func (s *testKVSuite) TestLoadRegions(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) cache := NewRegionsInfo() n := 10 @@ -152,7 +152,7 @@ func (s *testKVSuite) 
TestLoadRegions(c *C) { } func (s *testKVSuite) TestLoadRegionsToCache(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) cache := NewRegionsInfo() n := 10 @@ -171,7 +171,7 @@ func (s *testKVSuite) TestLoadRegionsToCache(c *C) { } func (s *testKVSuite) TestLoadRegionsExceedRangeLimit(c *C) { - storage := NewStorage(&KVWithMaxRangeLimit{Base: kv.NewMemoryKV(), rangeLimit: 500}) + storage := NewStorage(&KVWithMaxRangeLimit{Base: kv.NewMemoryKV(), rangeLimit: 500}, nil, nil) cache := NewRegionsInfo() n := 1000 @@ -184,7 +184,7 @@ func (s *testKVSuite) TestLoadRegionsExceedRangeLimit(c *C) { } func (s *testKVSuite) TestLoadGCSafePoint(c *C) { - storage := NewStorage(kv.NewMemoryKV()) + storage := NewStorage(kv.NewMemoryKV(), nil, nil) testData := []uint64{0, 1, 2, 233, 2333, 23333333333, math.MaxUint64} r, e := storage.LoadGCSafePoint() @@ -201,7 +201,7 @@ func (s *testKVSuite) TestLoadGCSafePoint(c *C) { func (s *testKVSuite) TestSaveServiceGCSafePoint(c *C) { mem := kv.NewMemoryKV() - storage := NewStorage(mem) + storage := NewStorage(mem, nil, nil) expireAt := time.Now().Add(100 * time.Second).Unix() serviceSafePoints := []*ServiceSafePoint{ {"1", expireAt, 1}, @@ -233,7 +233,7 @@ func (s *testKVSuite) TestSaveServiceGCSafePoint(c *C) { func (s *testKVSuite) TestLoadMinServiceGCSafePoint(c *C) { mem := kv.NewMemoryKV() - storage := NewStorage(mem) + storage := NewStorage(mem, nil, nil) expireAt := time.Now().Add(1000 * time.Second).Unix() serviceSafePoints := []*ServiceSafePoint{ {"1", 0, 1}, diff --git a/server/replication/replication_mode_test.go b/server/replication/replication_mode_test.go index 7b4c2138524..b85f5d31002 100644 --- a/server/replication/replication_mode_test.go +++ b/server/replication/replication_mode_test.go @@ -37,7 +37,7 @@ var _ = Suite(&testReplicationMode{}) type testReplicationMode struct{} func (s *testReplicationMode) TestInitial(c *C) { - store := core.NewStorage(kv.NewMemoryKV()) + 
store := core.NewStorage(kv.NewMemoryKV(), nil, nil) conf := config.ReplicationModeConfig{ReplicationMode: modeMajority} cluster := mockcluster.NewCluster(config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, nil) @@ -67,7 +67,7 @@ func (s *testReplicationMode) TestInitial(c *C) { } func (s *testReplicationMode) TestStatus(c *C) { - store := core.NewStorage(kv.NewMemoryKV()) + store := core.NewStorage(kv.NewMemoryKV(), nil, nil) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, @@ -137,7 +137,7 @@ func (rep *mockFileReplicator) ReplicateFileToAllMembers(context.Context, string } func (s *testReplicationMode) TestStateSwitch(c *C) { - store := core.NewStorage(kv.NewMemoryKV()) + store := core.NewStorage(kv.NewMemoryKV(), nil, nil) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", @@ -240,7 +240,7 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { } func (s *testReplicationMode) TestAsynctimeout(c *C) { - store := core.NewStorage(kv.NewMemoryKV()) + store := core.NewStorage(kv.NewMemoryKV(), nil, nil) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", @@ -292,7 +292,7 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { regionScanBatchSize = 10 regionMinSampleSize = 5 - store := core.NewStorage(kv.NewMemoryKV()) + store := core.NewStorage(kv.NewMemoryKV(), nil, nil) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", diff --git a/server/schedule/placement/rule_manager_test.go b/server/schedule/placement/rule_manager_test.go index b5806a43633..408afae30c1 100644 --- 
a/server/schedule/placement/rule_manager_test.go +++ b/server/schedule/placement/rule_manager_test.go @@ -30,7 +30,7 @@ type testManagerSuite struct { } func (s *testManagerSuite) SetUpTest(c *C) { - s.store = core.NewStorage(kv.NewMemoryKV()) + s.store = core.NewStorage(kv.NewMemoryKV(), nil, nil) var err error s.manager = NewRuleManager(s.store) err = s.manager.Initialize(3, []string{"zone", "rack", "host"}) diff --git a/server/schedulers/balance_test.go b/server/schedulers/balance_test.go index caa61245e1b..225365d7bfc 100644 --- a/server/schedulers/balance_test.go +++ b/server/schedulers/balance_test.go @@ -185,7 +185,7 @@ func (s *testBalanceLeaderSchedulerSuite) SetUpTest(c *C) { s.opt = config.NewTestOptions() s.tc = mockcluster.NewCluster(s.opt) s.oc = schedule.NewOperatorController(s.ctx, s.tc, nil) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) s.lb = lb } @@ -502,28 +502,28 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestSingleRangeBalance(c *C) { s.tc.UpdateStoreLeaderWeight(3, 1) s.tc.UpdateStoreLeaderWeight(4, 2) s.tc.AddLeaderRegionWithRange(1, "a", "g", 1, 2, 3, 4) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) ops := lb.Schedule(s.tc) c.Assert(ops, NotNil) c.Assert(ops, HasLen, 1) c.Assert(ops[0].Counters, HasLen, 5) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), 
schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", "f"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", "f"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "a"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "a"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"g", ""})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"g", ""})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "f"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "f"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), 
schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", ""})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", ""})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) } @@ -542,7 +542,7 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestMultiRangeBalance(c *C) { s.tc.UpdateStoreLeaderWeight(3, 1) s.tc.UpdateStoreLeaderWeight(4, 2) s.tc.AddLeaderRegionWithRange(1, "a", "g", 1, 2, 3, 4) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "g", "o", "t"})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "g", "o", "t"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc)[0].RegionID(), Equals, uint64(1)) s.tc.RemoveRegion(s.tc.GetRegion(1)) @@ -580,7 +580,7 @@ func (s *testBalanceRegionSchedulerSuite) TestBalance(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) opt.SetMaxReplicas(1) @@ -616,7 +616,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplicas3(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), 
nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplica3(c, tc, opt, sb) @@ -679,7 +679,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplicas5(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplica5(c, tc, opt, sb) @@ -771,7 +771,7 @@ func (s *testBalanceRegionSchedulerSuite) TestBalance1(c *C) { core.SetApproximateKeys(200), ) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) tc.AddRegionStore(1, 11) @@ -814,7 +814,7 @@ func (s *testBalanceRegionSchedulerSuite) TestStoreWeight(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) opt.SetMaxReplicas(1) @@ -842,7 +842,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplacePendingRegion(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := 
schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplacePendingRegion(c, tc, opt, sb) @@ -856,7 +856,7 @@ func (s *testBalanceRegionSchedulerSuite) TestOpInfluence(c *C) { tc.DisableFeature(versioninfo.JointConsensus) stream := hbstream.NewTestHeartbeatStreams(s.ctx, tc.ID, tc, false /* no need to run */) oc := schedule.NewOperatorController(s.ctx, tc, stream) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) opt.SetMaxReplicas(1) // Add stores 1,2,3,4. 
@@ -912,7 +912,7 @@ func (s *testRandomMergeSchedulerSuite) TestMerge(c *C) { stream := hbstream.NewTestHeartbeatStreams(ctx, tc.ID, tc, true /* need to run */) oc := schedule.NewOperatorController(ctx, tc, stream) - mb, err := schedule.CreateScheduler(RandomMergeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(RandomMergeType, []string{"", ""})) + mb, err := schedule.CreateScheduler(RandomMergeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(RandomMergeType, []string{"", ""})) c.Assert(err, IsNil) tc.AddRegionStore(1, 4) @@ -1001,7 +1001,7 @@ func (s *testScatterRangeLeaderSuite) TestBalance(c *C) { } oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) c.Assert(err, IsNil) limit := 0 for { @@ -1027,7 +1027,7 @@ func (s *testScatterRangeLeaderSuite) TestConcurrencyUpdateConfig(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) sche := hb.(*scatterRangeScheduler) c.Assert(err, IsNil) ch := make(chan struct{}) @@ -1099,7 +1099,7 @@ func (s *testScatterRangeLeaderSuite) TestBalanceWhenRegionNotHeartbeat(c *C) { } oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, 
core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_09", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_09", "t"})) c.Assert(err, IsNil) limit := 0 From 2bad0b960f09ca2196d54ac2e5629b8faf4f713e Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 11:49:10 +0800 Subject: [PATCH 05/37] fix tests Signed-off-by: Yi Wu --- server/cluster/cluster_worker_test.go | 4 ++-- server/cluster/coordinator_test.go | 8 ++++---- server/config/config_test.go | 4 ++-- server/schedulers/hot_test.go | 18 +++++++++--------- server/statistics/region_collection_test.go | 2 +- tests/server/cluster/cluster_test.go | 12 ++++++------ 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/server/cluster/cluster_worker_test.go b/server/cluster/cluster_worker_test.go index 837640959ec..03790368993 100644 --- a/server/cluster/cluster_worker_test.go +++ b/server/cluster/cluster_worker_test.go @@ -31,7 +31,7 @@ type testClusterWorkerSuite struct{} func (s *testClusterWorkerSuite) TestReportSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) left := &metapb.Region{Id: 1, StartKey: []byte("a"), EndKey: []byte("b")} right := &metapb.Region{Id: 2, StartKey: []byte("b"), EndKey: []byte("c")} _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: left, Right: right}) @@ -43,7 +43,7 @@ func (s *testClusterWorkerSuite) TestReportSplit(c *C) { func (s *testClusterWorkerSuite) TestReportBatchSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, 
core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) regions := []*metapb.Region{ {Id: 1, StartKey: []byte(""), EndKey: []byte("a")}, {Id: 2, StartKey: []byte("a"), EndKey: []byte("b")}, diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go index b7e9959e1c8..bd883d07910 100644 --- a/server/cluster/coordinator_test.go +++ b/server/cluster/coordinator_test.go @@ -601,12 +601,12 @@ func (s *testCoordinatorSuite) TestAddScheduler(c *C) { c.Assert(tc.addLeaderRegion(3, 3, 1, 2), IsNil) oc := co.opController - gls, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"})) + gls, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"})) c.Assert(err, IsNil) c.Assert(co.addScheduler(gls), NotNil) c.Assert(co.removeScheduler(gls.GetName()), NotNil) - gls, err = schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) + gls, err = schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) c.Assert(err, IsNil) c.Assert(co.addScheduler(gls), IsNil) @@ -1014,7 +1014,7 @@ func (s *testScheduleControllerSuite) TestController(c *C) { c.Assert(tc.addLeaderRegion(1, 1), IsNil) c.Assert(tc.addLeaderRegion(2, 2), IsNil) - scheduler, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) + scheduler, err := 
schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) lb := &mockLimitScheduler{ Scheduler: scheduler, @@ -1098,7 +1098,7 @@ func (s *testScheduleControllerSuite) TestInterval(c *C) { _, co, cleanup := prepare(nil, nil, nil, c) defer cleanup() - lb, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, co.opController, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, co.opController, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) sc := newScheduleController(co, lb) diff --git a/server/config/config_test.go b/server/config/config_test.go index 9464cd0247c..b63a164f5ff 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -60,7 +60,7 @@ func (s *testConfigSuite) TestBadFormatJoinAddr(c *C) { func (s *testConfigSuite) TestReloadConfig(c *C) { opt, err := newTestScheduleOption() c.Assert(err, IsNil) - storage := core.NewStorage(kv.NewMemoryKV()) + storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) scheduleCfg := opt.GetScheduleConfig() scheduleCfg.MaxSnapshotCount = 10 opt.SetMaxReplicas(5) @@ -100,7 +100,7 @@ func (s *testConfigSuite) TestReloadUpgrade(c *C) { Schedule: *opt.GetScheduleConfig(), Replication: *opt.GetReplicationConfig(), } - storage := core.NewStorage(kv.NewMemoryKV()) + storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) c.Assert(storage.SaveConfig(old), IsNil) newOpt, err := newTestScheduleOption() diff --git a/server/schedulers/hot_test.go b/server/schedulers/hot_test.go index aad95f043d4..8580fed6f82 100644 --- a/server/schedulers/hot_test.go +++ b/server/schedulers/hot_test.go @@ -52,7 +52,7 @@ func (s *testHotSchedulerSuite) 
TestGCPendingOpInfos(c *C) { tc.PutStoreWithLabels(id) } - sche, err := schedule.CreateScheduler(HotRegionType, schedule.NewOperatorController(ctx, tc, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigJSONDecoder([]byte("null"))) + sche, err := schedule.CreateScheduler(HotRegionType, schedule.NewOperatorController(ctx, tc, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigJSONDecoder([]byte("null"))) c.Assert(err, IsNil) hb := sche.(*hotScheduler) @@ -135,7 +135,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestByteRateOnly(c *C) { tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "host"}) tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) @@ -310,7 +310,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithKeyRate(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) hb.(*hotScheduler).conf.SetDstToleranceRatio(1) hb.(*hotScheduler).conf.SetSrcToleranceRatio(1) @@ -364,7 +364,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestLeader(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, 
schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) tc := mockcluster.NewCluster(opt) @@ -405,7 +405,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithPendingInfluence(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) for i := 0; i < 2; i++ { // 0: byte rate @@ -491,7 +491,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithRuleEnabled(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) tc.SetEnablePlacementRules(true) - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) key, err := hex.DecodeString("") @@ -567,7 +567,7 @@ func (s *testHotReadRegionSchedulerSuite) TestByteRateOnly(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) @@ -670,7 +670,7 @@ func (s *testHotReadRegionSchedulerSuite) TestWithKeyRate(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := 
schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) hb.(*hotScheduler).conf.SetSrcToleranceRatio(1) hb.(*hotScheduler).conf.SetDstToleranceRatio(1) @@ -722,7 +722,7 @@ func (s *testHotReadRegionSchedulerSuite) TestWithPendingInfluence(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) // For test hb.(*hotScheduler).conf.GreatDecRatio = 0.99 diff --git a/server/statistics/region_collection_test.go b/server/statistics/region_collection_test.go index 53beaf99ca2..0c3b9e26450 100644 --- a/server/statistics/region_collection_test.go +++ b/server/statistics/region_collection_test.go @@ -31,7 +31,7 @@ type testRegionStatisticsSuite struct { } func (t *testRegionStatisticsSuite) SetUpTest(c *C) { - t.store = core.NewStorage(kv.NewMemoryKV()) + t.store = core.NewStorage(kv.NewMemoryKV(), nil, nil) var err error t.manager = placement.NewRuleManager(t.store) err = t.manager.Initialize(3, []string{"zone", "rack", "host"}) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 39071900a54..359e840b0a1 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -460,7 +460,7 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - 
rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) var stores []*metapb.Store id := leaderServer.GetAllocator() for _, addr := range storeAddrs { @@ -607,7 +607,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { // PUT GET failed oldStorage := svr.GetStorage() - svr.SetStorage(core.NewStorage(&testErrorKV{})) + svr.SetStorage(core.NewStorage(&testErrorKV{}), nil, nil) replicationCfg.MaxReplicas = 7 scheduleCfg.MaxSnapshotCount = 20 pdServerCfg.UseRegionStorage = false @@ -626,7 +626,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { svr.SetStorage(oldStorage) c.Assert(svr.SetReplicationConfig(*replicationCfg), IsNil) - svr.SetStorage(core.NewStorage(&testErrorKV{})) + svr.SetStorage(core.NewStorage(&testErrorKV{}), nil, nil) c.Assert(svr.DeleteLabelProperty(typ, labelKey, labelValue), NotNil) c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Key, Equals, "testKey") @@ -894,7 +894,7 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) id := leaderServer.GetAllocator() for _, addr := range storeAddrs { storeID, err := id.Alloc() @@ -981,7 +981,7 @@ func (s *clusterTestSuite) TestUpgradeStoreLimit(c *C) { bootstrapCluster(c, clusterID, grpcPDClient, "127.0.0.1:0") rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) store := newMetaStore(1, "127.0.1.1:0", "4.0.0", metapb.StoreState_Up, "test/store1") _, err = putStore(c, grpcPDClient, clusterID, store) c.Assert(err, IsNil) @@ -1039,7 +1039,7 @@ func (s *clusterTestSuite) TestStaleTermHeartbeat(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := 
leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) var peers []*metapb.Peer id := leaderServer.GetAllocator() for _, addr := range storeAddrs { From f88d26a57729818e072a316a0b01e272032a7bcd Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 11:55:27 +0800 Subject: [PATCH 06/37] fix tests Signed-off-by: Yi Wu --- server/schedulers/scheduler_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/server/schedulers/scheduler_test.go b/server/schedulers/scheduler_test.go index 1e030126302..8bfbcfe5962 100644 --- a/server/schedulers/scheduler_test.go +++ b/server/schedulers/scheduler_test.go @@ -51,7 +51,7 @@ func (s *testShuffleLeaderSuite) TestShuffle(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sl, err := schedule.CreateScheduler(ShuffleLeaderType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleLeaderType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleLeaderType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleLeaderType, []string{"", ""})) c.Assert(err, IsNil) c.Assert(sl.Schedule(tc), IsNil) @@ -93,7 +93,7 @@ func (s *testBalanceAdjacentRegionSuite) TestBalance(c *C) { tc := mockcluster.NewCluster(opt) tc.DisableFeature(versioninfo.JointConsensus) - sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(AdjacentRegionType, []string{"32", "2"})) + sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(AdjacentRegionType, []string{"32", "2"})) c.Assert(err, IsNil) 
c.Assert(sc.(*balanceAdjacentRegionScheduler).conf.LeaderLimit, Equals, uint64(32)) @@ -161,7 +161,7 @@ func (s *testBalanceAdjacentRegionSuite) TestNoNeedToBalance(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(AdjacentRegionType, nil)) + sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(AdjacentRegionType, nil)) c.Assert(err, IsNil) c.Assert(sc.Schedule(tc), IsNil) @@ -199,7 +199,7 @@ func (s *testRejectLeaderSuite) TestRejectLeader(c *C) { // The label scheduler transfers leader out of store1. oc := schedule.NewOperatorController(ctx, nil, nil) - sl, err := schedule.CreateScheduler(LabelType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(LabelType, []string{"", ""})) + sl, err := schedule.CreateScheduler(LabelType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(LabelType, []string{"", ""})) c.Assert(err, IsNil) op := sl.Schedule(tc) testutil.CheckTransferLeaderFrom(c, op[0], operator.OpLeader, 1) @@ -211,13 +211,13 @@ func (s *testRejectLeaderSuite) TestRejectLeader(c *C) { // As store3 is disconnected, store1 rejects leader. Balancer will not create // any operators. - bs, err := schedule.CreateScheduler(BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + bs, err := schedule.CreateScheduler(BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) op = bs.Schedule(tc) c.Assert(op, IsNil) // Can't evict leader from store2, neither. 
- el, err := schedule.CreateScheduler(EvictLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"2"})) + el, err := schedule.CreateScheduler(EvictLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"2"})) c.Assert(err, IsNil) op = el.Schedule(tc) c.Assert(op, IsNil) @@ -248,7 +248,7 @@ func (s *testShuffleHotRegionSchedulerSuite) TestBalance(c *C) { tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "host"}) tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""})) + hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""})) c.Assert(err, IsNil) s.checkBalance(c, tc, opt, hb) @@ -307,7 +307,7 @@ func (s *testHotRegionSchedulerSuite) TestAbnormalReplica(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) tc.SetLeaderScheduleLimit(0) - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) c.Assert(err, IsNil) tc.AddRegionStore(1, 3) @@ -346,7 +346,7 @@ func (s *testEvictLeaderSuite) TestEvictLeader(c *C) { tc.AddLeaderRegion(2, 2, 1) tc.AddLeaderRegion(3, 3, 1) - sl, err := schedule.CreateScheduler(EvictLeaderType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"1"})) + sl, err := schedule.CreateScheduler(EvictLeaderType, 
schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"1"})) c.Assert(err, IsNil) c.Assert(sl.IsScheduleAllowed(tc), IsTrue) op := sl.Schedule(tc) @@ -363,7 +363,7 @@ func (s *testShuffleRegionSuite) TestShuffle(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) c.Assert(err, IsNil) c.Assert(sl.IsScheduleAllowed(tc), IsTrue) c.Assert(sl.Schedule(tc), IsNil) @@ -427,7 +427,7 @@ func (s *testShuffleRegionSuite) TestRole(c *C) { }, peers[0]) tc.PutRegion(region) - sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) c.Assert(err, IsNil) conf := sl.(*shuffleRegionScheduler).conf @@ -449,7 +449,7 @@ func (s *testSpecialUseSuite) TestSpecialUseHotRegion(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() oc := schedule.NewOperatorController(ctx, nil, nil) - storage := core.NewStorage(kv.NewMemoryKV()) + storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) cd := schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""}) bs, err := schedule.CreateScheduler(BalanceRegionType, oc, storage, cd) c.Assert(err, IsNil) @@ -502,7 +502,7 @@ func (s 
*testSpecialUseSuite) TestSpecialUseReserved(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() oc := schedule.NewOperatorController(ctx, nil, nil) - storage := core.NewStorage(kv.NewMemoryKV()) + storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) cd := schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""}) bs, err := schedule.CreateScheduler(BalanceRegionType, oc, storage, cd) c.Assert(err, IsNil) @@ -549,7 +549,7 @@ func (s *testBalanceLeaderSchedulerWithRuleEnabledSuite) SetUpTest(c *C) { s.tc = mockcluster.NewCluster(s.opt) s.tc.SetEnablePlacementRules(true) s.oc = schedule.NewOperatorController(s.ctx, nil, nil) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) s.lb = lb } From 6f6d93a12f822eb5d5838a029a44f47a6a38059d Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Thu, 17 Sep 2020 12:32:49 +0800 Subject: [PATCH 07/37] fix tests Signed-off-by: Yi Wu --- tests/client/client_tls_test.go | 2 +- tests/server/cluster/cluster_test.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/client/client_tls_test.go b/tests/client/client_tls_test.go index 999c9492bfe..e1886edcff8 100644 --- a/tests/client/client_tls_test.go +++ b/tests/client/client_tls_test.go @@ -128,7 +128,7 @@ func (s *clientTLSTestSuite) testTLSReload( tlsInfo := cloneFunc() // 1. 
start cluster with valid certs clus, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { - conf.Security = grpcutil.SecurityConfig{ + conf.Security.SecurityConfig = grpcutil.SecurityConfig{ KeyPath: tlsInfo.KeyFile, CertPath: tlsInfo.CertFile, CAPath: tlsInfo.TrustedCAFile, diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 359e840b0a1..8abb6c1f727 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -607,7 +607,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { // PUT GET failed oldStorage := svr.GetStorage() - svr.SetStorage(core.NewStorage(&testErrorKV{}), nil, nil) + svr.SetStorage(core.NewStorage(&testErrorKV{}, nil, nil)) replicationCfg.MaxReplicas = 7 scheduleCfg.MaxSnapshotCount = 20 pdServerCfg.UseRegionStorage = false @@ -626,7 +626,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { svr.SetStorage(oldStorage) c.Assert(svr.SetReplicationConfig(*replicationCfg), IsNil) - svr.SetStorage(core.NewStorage(&testErrorKV{}), nil, nil) + svr.SetStorage(core.NewStorage(&testErrorKV{}, nil, nil)) c.Assert(svr.DeleteLabelProperty(typ, labelKey, labelValue), NotNil) c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Key, Equals, "testKey") @@ -894,7 +894,7 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) id := leaderServer.GetAllocator() for _, addr := range storeAddrs { storeID, err := id.Alloc() @@ -981,7 +981,7 @@ func (s *clusterTestSuite) TestUpgradeStoreLimit(c *C) { bootstrapCluster(c, clusterID, grpcPDClient, "127.0.0.1:0") rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) + 
rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) store := newMetaStore(1, "127.0.1.1:0", "4.0.0", metapb.StoreState_Up, "test/store1") _, err = putStore(c, grpcPDClient, clusterID, store) c.Assert(err, IsNil) @@ -1039,7 +1039,7 @@ func (s *clusterTestSuite) TestStaleTermHeartbeat(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV()), nil, nil) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) var peers []*metapb.Peer id := leaderServer.GetAllocator() for _, addr := range storeAddrs { From 928de26c1a569e62f50bac3f43c3317df1fab588 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 06:27:46 +0800 Subject: [PATCH 08/37] address comment in #2931 Signed-off-by: Yi Wu --- pkg/encryption/crypter.go | 30 ++++++++++++++++-------------- pkg/encryption/crypter_test.go | 6 +++--- pkg/encryption/region_crypter.go | 2 +- server/core/storage_test.go | 2 +- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/pkg/encryption/crypter.go b/pkg/encryption/crypter.go index eb8ebe87a48..d35b75dbff9 100644 --- a/pkg/encryption/crypter.go +++ b/pkg/encryption/crypter.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "io" "time" + "unsafe" "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/tikv/pd/pkg/errs" @@ -67,11 +68,11 @@ func KeyLength(method encryptionpb.EncryptionMethod) (int, error) { } } -// IvCtr represent IV bytes for CTR mode. -type IvCtr []byte +// IvCTR represent IV bytes for CTR mode. +type IvCTR []byte -// IvGcm represent IV bytes for GCM mode. -type IvGcm []byte +// IvGCM represent IV bytes for GCM mode. +type IvGCM []byte func newIV(ivLength int) ([]byte, error) { iv := make([]byte, ivLength) @@ -86,13 +87,13 @@ func newIV(ivLength int) ([]byte, error) { return iv, nil } -// NewIvCtr randomly generate an IV for CTR mode. 
-func NewIvCtr() (IvCtr, error) { +// NewIvCTR randomly generate an IV for CTR mode. +func NewIvCTR() (IvCTR, error) { return newIV(ivLengthCTR) } -// NewIvGcm randomly generate an IV for GCM mode. -func NewIvGcm() (IvGcm, error) { +// NewIvGCM randomly generate an IV for GCM mode. +func NewIvGCM() (IvGCM, error) { return newIV(ivLengthGCM) } @@ -104,14 +105,15 @@ func NewDataKey( if err != nil { return } - keyIDBuf := make([]byte, 8) + keyIDBufSize := unsafe.Sizeof(uint64(0)) + keyIDBuf := make([]byte, keyIDBufSize) n, err := io.ReadFull(rand.Reader, keyIDBuf) if err != nil { err = errs.ErrEncryptionNewDataKey.Wrap(err).GenWithStack( "fail to generate data key id") return } - if n != 8 { + if n != int(keyIDBufSize) { err = errs.ErrEncryptionNewDataKey.GenWithStack( "no enough random bytes to generate data key id, bytes %d", n) return @@ -145,7 +147,7 @@ func NewDataKey( func aesGcmEncryptImpl( key []byte, plaintext []byte, - iv IvGcm, + iv IvGCM, ) (ciphertext []byte, err error) { block, err := aes.NewCipher(key) if err != nil { @@ -166,8 +168,8 @@ func aesGcmEncryptImpl( func AesGcmEncrypt( key []byte, plaintext []byte, -) (ciphertext []byte, iv IvGcm, err error) { - iv, err = NewIvGcm() +) (ciphertext []byte, iv IvGCM, err error) { + iv, err = NewIvGCM() if err != nil { return } @@ -180,7 +182,7 @@ func AesGcmEncrypt( func AesGcmDecrypt( key []byte, ciphertext []byte, - iv IvGcm, + iv IvGCM, ) (plaintext []byte, err error) { if len(iv) != ivLengthGCM { err = errs.ErrEncryptionGCMDecrypt.GenWithStack("unexpected gcm iv length %d", len(iv)) diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 65b87e10ec8..d140117a435 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -55,10 +55,10 @@ func (s *testCrypterSuite) TestKeyLength(c *C) { } func (s *testCrypterSuite) TestNewIv(c *C) { - ivCtr, err := NewIvCtr() + ivCtr, err := NewIvCTR() c.Assert(err, IsNil) c.Assert(len([]byte(ivCtr)), Equals, ivLengthCTR) - 
ivGcm, err := NewIvGcm() + ivGcm, err := NewIvGCM() c.Assert(err, IsNil) c.Assert(len([]byte(ivGcm)), Equals, ivLengthGCM) } @@ -88,7 +88,7 @@ func (s *testCrypterSuite) TestAesGcmCrypter(c *C) { // encrypt ivBytes, err := hex.DecodeString("ba432b70336c40c39ba14c1b") c.Assert(err, IsNil) - iv := IvGcm(ivBytes) + iv := IvGCM(ivBytes) ciphertext, err := aesGcmEncryptImpl(key, plaintext, iv) c.Assert(err, IsNil) c.Assert(len([]byte(iv)), Equals, ivLengthGCM) diff --git a/pkg/encryption/region_crypter.go b/pkg/encryption/region_crypter.go index 31115f9d657..bd379d4b92a 100644 --- a/pkg/encryption/region_crypter.go +++ b/pkg/encryption/region_crypter.go @@ -63,7 +63,7 @@ func EncryptRegion(region *metapb.Region, keyManager KeyManager) error { if err != nil { return err } - iv, err := NewIvCtr() + iv, err := NewIvCTR() if err != nil { return err } diff --git a/server/core/storage_test.go b/server/core/storage_test.go index 2f2a7e05762..33d6668ff07 100644 --- a/server/core/storage_test.go +++ b/server/core/storage_test.go @@ -68,8 +68,8 @@ func (s *testKVSuite) TestBasic(c *C) { c.Assert(storage.SaveRegion(region), IsNil) newRegion := &metapb.Region{} ok, err = storage.LoadRegion(123, newRegion) - c.Assert(ok, IsTrue) c.Assert(err, IsNil) + c.Assert(ok, IsTrue) c.Assert(newRegion, DeepEquals, region) err = storage.DeleteRegion(region) c.Assert(err, IsNil) From 9b6bcfdec668464c9b4fe17b4c2362de6df7a453 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 06:55:24 +0800 Subject: [PATCH 09/37] fix loadRegion Signed-off-by: Yi Wu --- server/core/region_storage.go | 3 ++- server/core/storage.go | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 94b99debbff..2d4cc694157 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -150,9 +150,10 @@ func saveRegion( func loadRegion( kv kv.Base, encryptionKeyManager *encryption.KeyManager, + regionID uint64, region 
*metapb.Region, ) (ok bool, err error) { - value, err := kv.Load(regionPath(region.GetId())) + value, err := kv.Load(regionPath(regionID)) if err != nil { return false, err } diff --git a/server/core/storage.go b/server/core/storage.go index e43fa852c23..584c778b0a7 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -160,9 +160,9 @@ func (s *Storage) DeleteStore(store *metapb.Store) error { // LoadRegion loads one region from storage. func (s *Storage) LoadRegion(regionID uint64, region *metapb.Region) (ok bool, err error) { if atomic.LoadInt32(&s.useRegionStorage) > 0 { - return loadRegion(s.regionStorage, s.encryptionKeyManager, region) + return loadRegion(s.regionStorage, s.encryptionKeyManager, regionID, region) } - return loadRegion(s.Base, s.encryptionKeyManager, region) + return loadRegion(s.Base, s.encryptionKeyManager, regionID, region) } // LoadRegions loads all regions from storage to RegionsInfo. From bdd5de47a1df7ea01a60657669ee6d37154821aa Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 07:09:18 +0800 Subject: [PATCH 10/37] rename encryption_key_manager package Signed-off-by: Yi Wu --- server/core/region_storage.go | 22 +++++++++---------- server/core/storage.go | 6 ++--- .../key_manager.go | 2 +- server/server.go | 6 ++--- 4 files changed, 18 insertions(+), 18 deletions(-) rename server/{encryption => encryption_key_manager}/key_manager.go (98%) diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 2d4cc694157..940e97bbfe4 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -22,9 +22,9 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" - crypter "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" - "github.com/tikv/pd/server/encryption" + ekm "github.com/tikv/pd/server/encryption_key_manager" "github.com/tikv/pd/server/kv" ) @@ -33,7 +33,7 @@ var 
dirtyFlushTick = time.Second // RegionStorage is used to save regions. type RegionStorage struct { *kv.LeveldbKV - encryptionKeyManager *encryption.KeyManager + encryptionKeyManager *ekm.KeyManager mu sync.RWMutex batchRegions map[string]*metapb.Region batchSize int @@ -55,7 +55,7 @@ const ( func NewRegionStorage( ctx context.Context, path string, - encryptionKeyManager *encryption.KeyManager, + encryptionKeyManager *ekm.KeyManager, ) (*RegionStorage, error) { levelDB, err := kv.NewLeveldbKV(path) if err != nil { @@ -105,7 +105,7 @@ func (s *RegionStorage) backgroundFlush() { // SaveRegion saves one region to storage. func (s *RegionStorage) SaveRegion(region *metapb.Region) error { - err := crypter.EncryptRegion(region, s.encryptionKeyManager) + err := encryption.EncryptRegion(region, s.encryptionKeyManager) if err != nil { return err } @@ -133,10 +133,10 @@ func deleteRegion(kv kv.Base, region *metapb.Region) error { func saveRegion( kv kv.Base, - encryptionKeyManager *encryption.KeyManager, + encryptionKeyManager *ekm.KeyManager, region *metapb.Region, ) error { - err := crypter.EncryptRegion(region, encryptionKeyManager) + err := encryption.EncryptRegion(region, encryptionKeyManager) if err != nil { return err } @@ -149,7 +149,7 @@ func saveRegion( func loadRegion( kv kv.Base, - encryptionKeyManager *encryption.KeyManager, + encryptionKeyManager *ekm.KeyManager, regionID uint64, region *metapb.Region, ) (ok bool, err error) { @@ -164,13 +164,13 @@ func loadRegion( if err != nil { return true, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() } - err = crypter.DecryptRegion(region, encryptionKeyManager) + err = encryption.DecryptRegion(region, encryptionKeyManager) return true, err } func loadRegions( kv kv.Base, - encryptionKeyManager *encryption.KeyManager, + encryptionKeyManager *ekm.KeyManager, f func(region *RegionInfo) []*RegionInfo, ) error { nextID := uint64(0) @@ -195,7 +195,7 @@ func loadRegions( if err := region.Unmarshal([]byte(s)); err != nil { 
return errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() } - if err = crypter.DecryptRegion(region, encryptionKeyManager); err != nil { + if err = encryption.DecryptRegion(region, encryptionKeyManager); err != nil { return err } diff --git a/server/core/storage.go b/server/core/storage.go index 584c778b0a7..830278936e2 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -28,7 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/pd/pkg/errs" - "github.com/tikv/pd/server/encryption" + ekm "github.com/tikv/pd/server/encryption_key_manager" "github.com/tikv/pd/server/kv" "go.etcd.io/etcd/clientv3" ) @@ -55,7 +55,7 @@ const ( type Storage struct { kv.Base regionStorage *RegionStorage - encryptionKeyManager *encryption.KeyManager + encryptionKeyManager *ekm.KeyManager useRegionStorage int32 regionLoaded int32 mu sync.Mutex @@ -65,7 +65,7 @@ type Storage struct { func NewStorage( base kv.Base, regionStorage *RegionStorage, - encryptionKeyManager *encryption.KeyManager, + encryptionKeyManager *ekm.KeyManager, ) *Storage { return &Storage{ Base: base, diff --git a/server/encryption/key_manager.go b/server/encryption_key_manager/key_manager.go similarity index 98% rename from server/encryption/key_manager.go rename to server/encryption_key_manager/key_manager.go index 16f2a9c73bf..fc64f64b602 100644 --- a/server/encryption/key_manager.go +++ b/server/encryption_key_manager/key_manager.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package encryption +package encryption_key_manager import ( "github.com/pingcap/kvproto/pkg/encryptionpb" diff --git a/server/server.go b/server/server.go index 871ae4c33b3..64d84b6737d 100644 --- a/server/server.go +++ b/server/server.go @@ -47,7 +47,7 @@ import ( "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" - "github.com/tikv/pd/server/encryption" + ekm "github.com/tikv/pd/server/encryption_key_manager" "github.com/tikv/pd/server/id" "github.com/tikv/pd/server/kv" "github.com/tikv/pd/server/member" @@ -117,7 +117,7 @@ type Server struct { // a unique ID. idAllocator *id.AllocatorImpl // for encryption - encryptionKeyManager *encryption.KeyManager + encryptionKeyManager *ekm.KeyManager // for storage operation. storage *core.Storage // for basicCluster operation. @@ -360,7 +360,7 @@ func (s *Server) startServer(ctx context.Context) error { return err } kvBase := kv.NewEtcdKVBase(s.client, s.rootPath) - encryptionKeyManager, err := encryption.NewKeyManager(kvBase, &s.cfg.Security.Encryption) + encryptionKeyManager, err := ekm.NewKeyManager(kvBase, &s.cfg.Security.Encryption) if err != nil { return err } From 90de5ef64d519d71d866d66601cb3df8207dc309 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 07:19:24 +0800 Subject: [PATCH 11/37] fix lint Signed-off-by: Yi Wu --- pkg/encryption/config.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go index 935cb4de428..c2fa6ed6392 100644 --- a/pkg/encryption/config.go +++ b/pkg/encryption/config.go @@ -46,7 +46,7 @@ type Config struct { func (c *Config) Adjust() error { if len(c.DataEncryptionMethod) == 0 { - c.DataEncryptionMethod = methodPlaintext + c.DataEncryptionMethod = defaultDataEncryptionMethod } else { if _, err := c.GetMethod(); err != nil { return err @@ -100,7 +100,7 @@ func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { Backend: 
&encryptionpb.MasterKey_Kms{ Kms: &encryptionpb.MasterKeyKms{ Vendor: kmsVendorAWS, - KeyId: c.MasterKey.KmsKeyId, + KeyId: c.MasterKey.KmsKeyID, Region: c.MasterKey.KmsRegion, Endpoint: c.MasterKey.KmsEndpoint, }, @@ -130,7 +130,7 @@ type MasterKeyConfig struct { type MasterKeyKMSConfig struct { // KMS CMK key id. - KmsKeyId string `toml:"key-id" json:"key-id"` + KmsKeyID string `toml:"key-id" json:"key-id"` // KMS region of the CMK. KmsRegion string `toml:"region" json:"region"` // Custom endpoint to access KMS. From c16ecd84b10f597cbda32d6f7fa9c7e680bb06eb Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 07:40:43 +0800 Subject: [PATCH 12/37] fix lint Signed-off-by: Yi Wu --- server/core/region_storage.go | 12 ++++++------ server/core/storage.go | 6 +++--- .../key_manager.go | 2 +- server/server.go | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) rename server/{encryption_key_manager => encryptionkm}/key_manager.go (98%) diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 940e97bbfe4..0ca50765c79 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" - ekm "github.com/tikv/pd/server/encryption_key_manager" + "github.com/tikv/pd/server/encryptionkm" "github.com/tikv/pd/server/kv" ) @@ -33,7 +33,7 @@ var dirtyFlushTick = time.Second // RegionStorage is used to save regions. 
type RegionStorage struct { *kv.LeveldbKV - encryptionKeyManager *ekm.KeyManager + encryptionKeyManager *encryptionkm.KeyManager mu sync.RWMutex batchRegions map[string]*metapb.Region batchSize int @@ -55,7 +55,7 @@ const ( func NewRegionStorage( ctx context.Context, path string, - encryptionKeyManager *ekm.KeyManager, + encryptionKeyManager *encryptionkm.KeyManager, ) (*RegionStorage, error) { levelDB, err := kv.NewLeveldbKV(path) if err != nil { @@ -133,7 +133,7 @@ func deleteRegion(kv kv.Base, region *metapb.Region) error { func saveRegion( kv kv.Base, - encryptionKeyManager *ekm.KeyManager, + encryptionKeyManager *encryptionkm.KeyManager, region *metapb.Region, ) error { err := encryption.EncryptRegion(region, encryptionKeyManager) @@ -149,7 +149,7 @@ func saveRegion( func loadRegion( kv kv.Base, - encryptionKeyManager *ekm.KeyManager, + encryptionKeyManager *encryptionkm.KeyManager, regionID uint64, region *metapb.Region, ) (ok bool, err error) { @@ -170,7 +170,7 @@ func loadRegion( func loadRegions( kv kv.Base, - encryptionKeyManager *ekm.KeyManager, + encryptionKeyManager *encryptionkm.KeyManager, f func(region *RegionInfo) []*RegionInfo, ) error { nextID := uint64(0) diff --git a/server/core/storage.go b/server/core/storage.go index 830278936e2..750ffab0245 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -28,7 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/pd/pkg/errs" - ekm "github.com/tikv/pd/server/encryption_key_manager" + "github.com/tikv/pd/server/encryptionkm" "github.com/tikv/pd/server/kv" "go.etcd.io/etcd/clientv3" ) @@ -55,7 +55,7 @@ const ( type Storage struct { kv.Base regionStorage *RegionStorage - encryptionKeyManager *ekm.KeyManager + encryptionKeyManager *encryptionkm.KeyManager useRegionStorage int32 regionLoaded int32 mu sync.Mutex @@ -65,7 +65,7 @@ type Storage struct { func NewStorage( base kv.Base, regionStorage *RegionStorage, - encryptionKeyManager 
*ekm.KeyManager, + encryptionKeyManager *encryptionkm.KeyManager, ) *Storage { return &Storage{ Base: base, diff --git a/server/encryption_key_manager/key_manager.go b/server/encryptionkm/key_manager.go similarity index 98% rename from server/encryption_key_manager/key_manager.go rename to server/encryptionkm/key_manager.go index fc64f64b602..945f5348118 100644 --- a/server/encryption_key_manager/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package encryption_key_manager +package encryptionkm import ( "github.com/pingcap/kvproto/pkg/encryptionpb" diff --git a/server/server.go b/server/server.go index 64d84b6737d..2134eeefb29 100644 --- a/server/server.go +++ b/server/server.go @@ -47,7 +47,7 @@ import ( "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" - ekm "github.com/tikv/pd/server/encryption_key_manager" + "github.com/tikv/pd/server/encryptionkm" "github.com/tikv/pd/server/id" "github.com/tikv/pd/server/kv" "github.com/tikv/pd/server/member" @@ -117,7 +117,7 @@ type Server struct { // a unique ID. idAllocator *id.AllocatorImpl // for encryption - encryptionKeyManager *ekm.KeyManager + encryptionKeyManager *encryptionkm.KeyManager // for storage operation. storage *core.Storage // for basicCluster operation. 
@@ -360,7 +360,7 @@ func (s *Server) startServer(ctx context.Context) error { return err } kvBase := kv.NewEtcdKVBase(s.client, s.rootPath) - encryptionKeyManager, err := ekm.NewKeyManager(kvBase, &s.cfg.Security.Encryption) + encryptionKeyManager, err := encryptionkm.NewKeyManager(kvBase, &s.cfg.Security.Encryption) if err != nil { return err } From 2428248f520b7f5260fff88a8a069ec3a26501ae Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 08:27:41 +0800 Subject: [PATCH 13/37] fix comments Signed-off-by: Yi Wu --- pkg/encryption/config.go | 7 +++++++ server/core/region_storage.go | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go index c2fa6ed6392..e690f8bfec5 100644 --- a/pkg/encryption/config.go +++ b/pkg/encryption/config.go @@ -35,6 +35,7 @@ const ( defaultDataKeyRotationPeriod = "168h" // 7 days ) +// Config defines the encryption config structure. type Config struct { // Encryption method to use for PD data. DataEncryptionMethod string `toml:"data-encryption-method" json:"data-encryption-method"` @@ -44,6 +45,7 @@ type Config struct { MasterKey MasterKeyConfig `toml:"master-key" json:"master-key"` } +// Adjust validates the config and sets default values. func (c *Config) Adjust() error { if len(c.DataEncryptionMethod) == 0 { c.DataEncryptionMethod = defaultDataEncryptionMethod @@ -71,6 +73,7 @@ func (c *Config) Adjust() error { return nil } +// GetMethod gets the encryption method. func (c *Config) GetMethod() (encryptionpb.EncryptionMethod, error) { switch c.DataEncryptionMethod { case methodPlaintext: @@ -87,6 +90,7 @@ func (c *Config) GetMethod() (encryptionpb.EncryptionMethod, error) { } } +// GetMasterKey gets the master key config.
func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { switch c.MasterKey.Type { case masterKeyTypePlaintext: @@ -120,6 +124,7 @@ func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { } } +// MasterKeyConfig defines master key config structure. type MasterKeyConfig struct { // Master key type, one of "plaintext", "kms" or "file". Type string `toml:"type" json:"type"` @@ -128,6 +133,7 @@ type MasterKeyConfig struct { MasterKeyFileConfig } +// MasterKeyKMSConfig defines a KMS master key config structure. type MasterKeyKMSConfig struct { // KMS CMK key id. KmsKeyID string `toml:"key-id" json:"key-id"` @@ -137,6 +143,7 @@ type MasterKeyKMSConfig struct { KmsEndpoint string `toml:"endpoint" json:"endpoint"` } +// MasterKeyFileConfig defines a file-based master key config structure. type MasterKeyFileConfig struct { // Master key file path. FilePath string `toml:"path" json:"path"` diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 0ca50765c79..3fead555f68 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -51,7 +51,7 @@ const ( defaultBatchSize = 100 ) -// newRegionStorage returns a region storage that is used to save regions. +// NewRegionStorage returns a region storage that is used to save regions. func NewRegionStorage( ctx context.Context, path string, From ce83a4235d9c0ee1bd4c38facc0302e774c296a0 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 18 Sep 2020 08:50:59 +0800 Subject: [PATCH 14/37] fix comment Signed-off-by: Yi Wu --- server/core/storage.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/core/storage.go b/server/core/storage.go index 750ffab0245..5f333c696c4 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -110,6 +110,7 @@ func (s *Storage) storeRegionWeightPath(storeID uint64) string { return path.Join(schedulePath, "store_weight", fmt.Sprintf("%020d", storeID), "region") } +// EncryptionKeysPath returns the path to save encryption keys. 
func (s *Storage) EncryptionKeysPath() string { return path.Join(encryptionKeysPath, "keys") } From 073f13e8630adf89c129df86056dd2a30c5773e7 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 25 Sep 2020 06:37:50 +0800 Subject: [PATCH 15/37] use option pattern Signed-off-by: Yi Wu --- pkg/component/manager_test.go | 2 +- server/cluster/cluster_test.go | 30 +++++-------- server/cluster/cluster_worker_test.go | 4 +- server/cluster/coordinator.go | 7 +-- server/cluster/coordinator_test.go | 8 ++-- server/config/config_test.go | 4 +- server/core/storage.go | 37 +++++++++++++--- server/core/storage_test.go | 20 ++++----- server/replication/replication_mode_test.go | 10 ++--- .../schedule/placement/rule_manager_test.go | 2 +- server/schedulers/balance_test.go | 43 +++++++++---------- server/schedulers/hot_test.go | 18 ++++---- server/schedulers/scheduler_test.go | 28 ++++++------ server/server.go | 6 ++- server/statistics/region_collection_test.go | 2 +- tests/server/cluster/cluster_test.go | 12 +++--- 16 files changed, 122 insertions(+), 111 deletions(-) diff --git a/pkg/component/manager_test.go b/pkg/component/manager_test.go index fe360294c1c..3c02a162729 100644 --- a/pkg/component/manager_test.go +++ b/pkg/component/manager_test.go @@ -31,7 +31,7 @@ var _ = Suite(&testManagerSuite{}) type testManagerSuite struct{} func (s *testManagerSuite) TestManager(c *C) { - m := NewManager(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + m := NewManager(core.NewStorage(kv.NewMemoryKV())) // register legal address c.Assert(m.Register("c1", "127.0.0.1:1"), IsNil) c.Assert(m.Register("c1", "127.0.0.1:2"), IsNil) diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 55e7aa76af8..7c280226a95 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -46,8 +46,7 @@ type testClusterInfoSuite struct{} func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := 
newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) n, np := uint64(3), uint64(3) stores := newTestStores(n) @@ -96,8 +95,7 @@ func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) stores := newTestStores(3) for _, store := range stores { @@ -129,8 +127,7 @@ func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) n, np := uint64(3), uint64(3) @@ -355,8 +352,7 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} processRegions := func(regions []*core.RegionInfo) { for _, r := range regions { @@ -384,8 +380,7 @@ func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { func (s 
*testClusterInfoSuite) TestConcurrentRegionHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} regions = core.SplitRegions(regions) @@ -446,8 +441,7 @@ func heartbeatRegions(c *C, cluster *RaftCluster, regions []*core.RegionInfo) { func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) // 1: [nil, nil) region1 := core.NewRegionInfo(&metapb.Region{Id: 1, RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) @@ -486,8 +480,7 @@ func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { func (s *testClusterInfoSuite) TestRegionSplitAndMerge(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} @@ -594,8 +587,7 @@ func (s *testRegionsInfoSuite) Test(c *C) { regions := newTestRegions(n, np) _, opts, err := newTestScheduleConfig() c.Assert(err, IsNil) - tc := newTestRaftCluster( - mockid.NewIDAllocator(), opts, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + tc := newTestRaftCluster(mockid.NewIDAllocator(), opts, 
core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) cache := tc.core.Regions for i := uint64(0); i < n; i++ { @@ -705,8 +697,7 @@ type testGetStoresSuite struct { func (s *testGetStoresSuite) SetUpSuite(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) s.cluster = cluster stores := newTestStores(200) @@ -739,8 +730,7 @@ func newTestScheduleConfig() (*config.ScheduleConfig, *config.PersistOptions, er } func newTestCluster(opt *config.PersistOptions) *testCluster { - rc := newTestRaftCluster( - mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + rc := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) return &testCluster{RaftCluster: rc} } diff --git a/server/cluster/cluster_worker_test.go b/server/cluster/cluster_worker_test.go index 03790368993..837640959ec 100644 --- a/server/cluster/cluster_worker_test.go +++ b/server/cluster/cluster_worker_test.go @@ -31,7 +31,7 @@ type testClusterWorkerSuite struct{} func (s *testClusterWorkerSuite) TestReportSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) left := &metapb.Region{Id: 1, StartKey: []byte("a"), EndKey: []byte("b")} right := &metapb.Region{Id: 2, StartKey: []byte("b"), EndKey: []byte("c")} _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: left, Right: right}) @@ -43,7 +43,7 @@ func (s *testClusterWorkerSuite) TestReportSplit(c *C) { func (s 
*testClusterWorkerSuite) TestReportBatchSplit(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) - cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV(), nil, nil), core.NewBasicCluster()) + cluster := newTestRaftCluster(mockid.NewIDAllocator(), opt, core.NewStorage(kv.NewMemoryKV()), core.NewBasicCluster()) regions := []*metapb.Region{ {Id: 1, StartKey: []byte(""), EndKey: []byte("a")}, {Id: 2, StartKey: []byte("a"), EndKey: []byte("b")}, diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go index 4f2422f05d4..b2335565bf9 100644 --- a/server/cluster/coordinator.go +++ b/server/cluster/coordinator.go @@ -590,12 +590,7 @@ func (c *coordinator) removeOptScheduler(o *config.PersistOptions, name string) for i, schedulerCfg := range v.Schedulers { // To create a temporary scheduler is just used to get scheduler's name decoder := schedule.ConfigSliceDecoder(schedulerCfg.Type, schedulerCfg.Args) - tmp, err := schedule.CreateScheduler( - schedulerCfg.Type, - schedule.NewOperatorController(c.ctx, nil, nil), - core.NewStorage(kv.NewMemoryKV(), nil, nil), - decoder, - ) + tmp, err := schedule.CreateScheduler(schedulerCfg.Type, schedule.NewOperatorController(c.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), decoder) if err != nil { return err } diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go index bd883d07910..b7e9959e1c8 100644 --- a/server/cluster/coordinator_test.go +++ b/server/cluster/coordinator_test.go @@ -601,12 +601,12 @@ func (s *testCoordinatorSuite) TestAddScheduler(c *C) { c.Assert(tc.addLeaderRegion(3, 3, 1, 2), IsNil) oc := co.opController - gls, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"})) + gls, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), 
schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"})) c.Assert(err, IsNil) c.Assert(co.addScheduler(gls), NotNil) c.Assert(co.removeScheduler(gls.GetName()), NotNil) - gls, err = schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) + gls, err = schedule.CreateScheduler(schedulers.GrantLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) c.Assert(err, IsNil) c.Assert(co.addScheduler(gls), IsNil) @@ -1014,7 +1014,7 @@ func (s *testScheduleControllerSuite) TestController(c *C) { c.Assert(tc.addLeaderRegion(1, 1), IsNil) c.Assert(tc.addLeaderRegion(2, 2), IsNil) - scheduler, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) + scheduler, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) lb := &mockLimitScheduler{ Scheduler: scheduler, @@ -1098,7 +1098,7 @@ func (s *testScheduleControllerSuite) TestInterval(c *C) { _, co, cleanup := prepare(nil, nil, nil, c) defer cleanup() - lb, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, co.opController, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, co.opController, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) sc := newScheduleController(co, lb) diff --git a/server/config/config_test.go b/server/config/config_test.go index b63a164f5ff..9464cd0247c 100644 --- a/server/config/config_test.go +++ 
b/server/config/config_test.go @@ -60,7 +60,7 @@ func (s *testConfigSuite) TestBadFormatJoinAddr(c *C) { func (s *testConfigSuite) TestReloadConfig(c *C) { opt, err := newTestScheduleOption() c.Assert(err, IsNil) - storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) + storage := core.NewStorage(kv.NewMemoryKV()) scheduleCfg := opt.GetScheduleConfig() scheduleCfg.MaxSnapshotCount = 10 opt.SetMaxReplicas(5) @@ -100,7 +100,7 @@ func (s *testConfigSuite) TestReloadUpgrade(c *C) { Schedule: *opt.GetScheduleConfig(), Replication: *opt.GetReplicationConfig(), } - storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) + storage := core.NewStorage(kv.NewMemoryKV()) c.Assert(storage.SaveConfig(old), IsNil) newOpt, err := newTestScheduleOption() diff --git a/server/core/storage.go index 5f333c696c4..d239bc88168 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -61,16 +61,39 @@ type Storage struct { mu sync.Mutex } +// StorageOpt represents available options to create Storage. +type StorageOpt struct { + regionStorage *RegionStorage + encryptionKeyManager *encryptionkm.KeyManager +} + +// StorageOption configures StorageOpt +type StorageOption func(*StorageOpt) + +// WithRegionStorage sets RegionStorage to the Storage +func WithRegionStorage(regionStorage *RegionStorage) StorageOption { + return func(opt *StorageOpt) { + opt.regionStorage = regionStorage + } +} + +// WithEncryptionKeyManager sets EncryptionKeyManager to the Storage +func WithEncryptionKeyManager(encryptionKeyManager *encryptionkm.KeyManager) StorageOption { + return func(opt *StorageOpt) { + opt.encryptionKeyManager = encryptionKeyManager + } +} + // NewStorage creates Storage instance with Base.
-func NewStorage( - base kv.Base, - regionStorage *RegionStorage, - encryptionKeyManager *encryptionkm.KeyManager, -) *Storage { +func NewStorage(base kv.Base, opts ...StorageOption) *Storage { + options := &StorageOpt{} + for _, opt := range opts { + opt(options) + } return &Storage{ Base: base, - regionStorage: regionStorage, - encryptionKeyManager: encryptionKeyManager, + regionStorage: options.regionStorage, + encryptionKeyManager: options.encryptionKeyManager, } } diff --git a/server/core/storage_test.go b/server/core/storage_test.go index 33d6668ff07..8ccbbffda6f 100644 --- a/server/core/storage_test.go +++ b/server/core/storage_test.go @@ -34,7 +34,7 @@ type testKVSuite struct { } func (s *testKVSuite) TestBasic(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + storage := NewStorage(kv.NewMemoryKV()) c.Assert(storage.storePath(123), Equals, "raft/s/00000000000000000123") c.Assert(regionPath(123), Equals, "raft/r/00000000000000000123") @@ -68,8 +68,8 @@ func (s *testKVSuite) TestBasic(c *C) { c.Assert(storage.SaveRegion(region), IsNil) newRegion := &metapb.Region{} ok, err = storage.LoadRegion(123, newRegion) - c.Assert(err, IsNil) c.Assert(ok, IsTrue) + c.Assert(err, IsNil) c.Assert(newRegion, DeepEquals, region) err = storage.DeleteRegion(region) c.Assert(err, IsNil) @@ -93,7 +93,7 @@ func mustSaveStores(c *C, s *Storage, n int) []*metapb.Store { } func (s *testKVSuite) TestLoadStores(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + storage := NewStorage(kv.NewMemoryKV()) cache := NewStoresInfo() n := 10 @@ -107,7 +107,7 @@ func (s *testKVSuite) TestLoadStores(c *C) { } func (s *testKVSuite) TestStoreWeight(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + storage := NewStorage(kv.NewMemoryKV()) cache := NewStoresInfo() const n = 3 @@ -138,7 +138,7 @@ func mustSaveRegions(c *C, s *Storage, n int) []*metapb.Region { } func (s *testKVSuite) TestLoadRegions(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + 
storage := NewStorage(kv.NewMemoryKV()) cache := NewRegionsInfo() n := 10 @@ -152,7 +152,7 @@ func (s *testKVSuite) TestLoadRegions(c *C) { } func (s *testKVSuite) TestLoadRegionsToCache(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + storage := NewStorage(kv.NewMemoryKV()) cache := NewRegionsInfo() n := 10 @@ -171,7 +171,7 @@ func (s *testKVSuite) TestLoadRegionsToCache(c *C) { } func (s *testKVSuite) TestLoadRegionsExceedRangeLimit(c *C) { - storage := NewStorage(&KVWithMaxRangeLimit{Base: kv.NewMemoryKV(), rangeLimit: 500}, nil, nil) + storage := NewStorage(&KVWithMaxRangeLimit{Base: kv.NewMemoryKV(), rangeLimit: 500}) cache := NewRegionsInfo() n := 1000 @@ -184,7 +184,7 @@ func (s *testKVSuite) TestLoadRegionsExceedRangeLimit(c *C) { } func (s *testKVSuite) TestLoadGCSafePoint(c *C) { - storage := NewStorage(kv.NewMemoryKV(), nil, nil) + storage := NewStorage(kv.NewMemoryKV()) testData := []uint64{0, 1, 2, 233, 2333, 23333333333, math.MaxUint64} r, e := storage.LoadGCSafePoint() @@ -201,7 +201,7 @@ func (s *testKVSuite) TestLoadGCSafePoint(c *C) { func (s *testKVSuite) TestSaveServiceGCSafePoint(c *C) { mem := kv.NewMemoryKV() - storage := NewStorage(mem, nil, nil) + storage := NewStorage(mem) expireAt := time.Now().Add(100 * time.Second).Unix() serviceSafePoints := []*ServiceSafePoint{ {"1", expireAt, 1}, @@ -233,7 +233,7 @@ func (s *testKVSuite) TestSaveServiceGCSafePoint(c *C) { func (s *testKVSuite) TestLoadMinServiceGCSafePoint(c *C) { mem := kv.NewMemoryKV() - storage := NewStorage(mem, nil, nil) + storage := NewStorage(mem) expireAt := time.Now().Add(1000 * time.Second).Unix() serviceSafePoints := []*ServiceSafePoint{ {"1", 0, 1}, diff --git a/server/replication/replication_mode_test.go b/server/replication/replication_mode_test.go index b85f5d31002..7b4c2138524 100644 --- a/server/replication/replication_mode_test.go +++ b/server/replication/replication_mode_test.go @@ -37,7 +37,7 @@ var _ = Suite(&testReplicationMode{}) type 
testReplicationMode struct{} func (s *testReplicationMode) TestInitial(c *C) { - store := core.NewStorage(kv.NewMemoryKV(), nil, nil) + store := core.NewStorage(kv.NewMemoryKV()) conf := config.ReplicationModeConfig{ReplicationMode: modeMajority} cluster := mockcluster.NewCluster(config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, nil) @@ -67,7 +67,7 @@ func (s *testReplicationMode) TestInitial(c *C) { } func (s *testReplicationMode) TestStatus(c *C) { - store := core.NewStorage(kv.NewMemoryKV(), nil, nil) + store := core.NewStorage(kv.NewMemoryKV()) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, @@ -137,7 +137,7 @@ func (rep *mockFileReplicator) ReplicateFileToAllMembers(context.Context, string } func (s *testReplicationMode) TestStateSwitch(c *C) { - store := core.NewStorage(kv.NewMemoryKV(), nil, nil) + store := core.NewStorage(kv.NewMemoryKV()) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", @@ -240,7 +240,7 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { } func (s *testReplicationMode) TestAsynctimeout(c *C) { - store := core.NewStorage(kv.NewMemoryKV(), nil, nil) + store := core.NewStorage(kv.NewMemoryKV()) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", @@ -292,7 +292,7 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { regionScanBatchSize = 10 regionMinSampleSize = 5 - store := core.NewStorage(kv.NewMemoryKV(), nil, nil) + store := core.NewStorage(kv.NewMemoryKV()) conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", Primary: "zone1", diff --git 
a/server/schedule/placement/rule_manager_test.go b/server/schedule/placement/rule_manager_test.go index 408afae30c1..b5806a43633 100644 --- a/server/schedule/placement/rule_manager_test.go +++ b/server/schedule/placement/rule_manager_test.go @@ -30,7 +30,7 @@ type testManagerSuite struct { } func (s *testManagerSuite) SetUpTest(c *C) { - s.store = core.NewStorage(kv.NewMemoryKV(), nil, nil) + s.store = core.NewStorage(kv.NewMemoryKV()) var err error s.manager = NewRuleManager(s.store) err = s.manager.Initialize(3, []string{"zone", "rack", "host"}) diff --git a/server/schedulers/balance_test.go b/server/schedulers/balance_test.go index 54283178935..caa61245e1b 100644 --- a/server/schedulers/balance_test.go +++ b/server/schedulers/balance_test.go @@ -185,7 +185,7 @@ func (s *testBalanceLeaderSchedulerSuite) SetUpTest(c *C) { s.opt = config.NewTestOptions() s.tc = mockcluster.NewCluster(s.opt) s.oc = schedule.NewOperatorController(s.ctx, s.tc, nil) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) s.lb = lb } @@ -502,29 +502,28 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestSingleRangeBalance(c *C) { s.tc.UpdateStoreLeaderWeight(3, 1) s.tc.UpdateStoreLeaderWeight(4, 2) s.tc.AddLeaderRegionWithRange(1, "a", "g", 1, 2, 3, 4) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) ops := lb.Schedule(s.tc) c.Assert(ops, NotNil) c.Assert(ops, HasLen, 1) - 
c.Assert(ops[0].Counters, HasLen, 3) - c.Assert(ops[0].FinishedCounters, HasLen, 2) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) + c.Assert(ops[0].Counters, HasLen, 5) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", "f"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", "f"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "a"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "a"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"g", ""})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"g", ""})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "f"})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), 
schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "f"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) - lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", ""})) + lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"b", ""})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) } @@ -543,7 +542,7 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestMultiRangeBalance(c *C) { s.tc.UpdateStoreLeaderWeight(3, 1) s.tc.UpdateStoreLeaderWeight(4, 2) s.tc.AddLeaderRegionWithRange(1, "a", "g", 1, 2, 3, 4) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "g", "o", "t"})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", "g", "o", "t"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc)[0].RegionID(), Equals, uint64(1)) s.tc.RemoveRegion(s.tc.GetRegion(1)) @@ -581,7 +580,7 @@ func (s *testBalanceRegionSchedulerSuite) TestBalance(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) opt.SetMaxReplicas(1) @@ -617,7 +616,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplicas3(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := 
schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplica3(c, tc, opt, sb) @@ -680,7 +679,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplicas5(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplica5(c, tc, opt, sb) @@ -772,7 +771,7 @@ func (s *testBalanceRegionSchedulerSuite) TestBalance1(c *C) { core.SetApproximateKeys(200), ) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) tc.AddRegionStore(1, 11) @@ -815,7 +814,7 @@ func (s *testBalanceRegionSchedulerSuite) TestStoreWeight(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) 
opt.SetMaxReplicas(1) @@ -843,7 +842,7 @@ func (s *testBalanceRegionSchedulerSuite) TestReplacePendingRegion(c *C) { tc.DisableFeature(versioninfo.JointConsensus) oc := schedule.NewOperatorController(s.ctx, nil, nil) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) s.checkReplacePendingRegion(c, tc, opt, sb) @@ -857,7 +856,7 @@ func (s *testBalanceRegionSchedulerSuite) TestOpInfluence(c *C) { tc.DisableFeature(versioninfo.JointConsensus) stream := hbstream.NewTestHeartbeatStreams(s.ctx, tc.ID, tc, false /* no need to run */) oc := schedule.NewOperatorController(s.ctx, tc, stream) - sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) + sb, err := schedule.CreateScheduler(BalanceRegionType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""})) c.Assert(err, IsNil) opt.SetMaxReplicas(1) // Add stores 1,2,3,4. 
@@ -913,7 +912,7 @@ func (s *testRandomMergeSchedulerSuite) TestMerge(c *C) { stream := hbstream.NewTestHeartbeatStreams(ctx, tc.ID, tc, true /* need to run */) oc := schedule.NewOperatorController(ctx, tc, stream) - mb, err := schedule.CreateScheduler(RandomMergeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(RandomMergeType, []string{"", ""})) + mb, err := schedule.CreateScheduler(RandomMergeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(RandomMergeType, []string{"", ""})) c.Assert(err, IsNil) tc.AddRegionStore(1, 4) @@ -1002,7 +1001,7 @@ func (s *testScatterRangeLeaderSuite) TestBalance(c *C) { } oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) c.Assert(err, IsNil) limit := 0 for { @@ -1028,7 +1027,7 @@ func (s *testScatterRangeLeaderSuite) TestConcurrencyUpdateConfig(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_50", "t"})) sche := hb.(*scatterRangeScheduler) c.Assert(err, IsNil) ch := make(chan struct{}) @@ -1100,7 +1099,7 @@ func (s *testScatterRangeLeaderSuite) TestBalanceWhenRegionNotHeartbeat(c *C) { } oc := schedule.NewOperatorController(s.ctx, nil, nil) - hb, err := schedule.CreateScheduler(ScatterRangeType, oc, 
core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_09", "t"})) + hb, err := schedule.CreateScheduler(ScatterRangeType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ScatterRangeType, []string{"s_00", "s_09", "t"})) c.Assert(err, IsNil) limit := 0 diff --git a/server/schedulers/hot_test.go b/server/schedulers/hot_test.go index 8580fed6f82..aad95f043d4 100644 --- a/server/schedulers/hot_test.go +++ b/server/schedulers/hot_test.go @@ -52,7 +52,7 @@ func (s *testHotSchedulerSuite) TestGCPendingOpInfos(c *C) { tc.PutStoreWithLabels(id) } - sche, err := schedule.CreateScheduler(HotRegionType, schedule.NewOperatorController(ctx, tc, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigJSONDecoder([]byte("null"))) + sche, err := schedule.CreateScheduler(HotRegionType, schedule.NewOperatorController(ctx, tc, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigJSONDecoder([]byte("null"))) c.Assert(err, IsNil) hb := sche.(*hotScheduler) @@ -135,7 +135,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestByteRateOnly(c *C) { tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "host"}) tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) @@ -310,7 +310,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithKeyRate(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, 
schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) hb.(*hotScheduler).conf.SetDstToleranceRatio(1) hb.(*hotScheduler).conf.SetSrcToleranceRatio(1) @@ -364,7 +364,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestLeader(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) tc := mockcluster.NewCluster(opt) @@ -405,7 +405,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithPendingInfluence(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) for i := 0; i < 2; i++ { // 0: byte rate @@ -491,7 +491,7 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithRuleEnabled(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) tc.SetEnablePlacementRules(true) - hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotWriteRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) key, err := hex.DecodeString("") @@ -567,7 +567,7 @@ func (s *testHotReadRegionSchedulerSuite) TestByteRateOnly(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) 
tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) tc.SetHotRegionCacheHitsThreshold(0) @@ -670,7 +670,7 @@ func (s *testHotReadRegionSchedulerSuite) TestWithKeyRate(c *C) { defer cancel() statistics.Denoising = false opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) hb.(*hotScheduler).conf.SetSrcToleranceRatio(1) hb.(*hotScheduler).conf.SetDstToleranceRatio(1) @@ -722,7 +722,7 @@ func (s *testHotReadRegionSchedulerSuite) TestWithPendingInfluence(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) // For test hb.(*hotScheduler).conf.GreatDecRatio = 0.99 diff --git a/server/schedulers/scheduler_test.go b/server/schedulers/scheduler_test.go index 8bfbcfe5962..1e030126302 100644 --- a/server/schedulers/scheduler_test.go +++ b/server/schedulers/scheduler_test.go @@ -51,7 +51,7 @@ func (s *testShuffleLeaderSuite) TestShuffle(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sl, err := schedule.CreateScheduler(ShuffleLeaderType, schedule.NewOperatorController(ctx, nil, nil), 
core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleLeaderType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleLeaderType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleLeaderType, []string{"", ""})) c.Assert(err, IsNil) c.Assert(sl.Schedule(tc), IsNil) @@ -93,7 +93,7 @@ func (s *testBalanceAdjacentRegionSuite) TestBalance(c *C) { tc := mockcluster.NewCluster(opt) tc.DisableFeature(versioninfo.JointConsensus) - sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(AdjacentRegionType, []string{"32", "2"})) + sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(AdjacentRegionType, []string{"32", "2"})) c.Assert(err, IsNil) c.Assert(sc.(*balanceAdjacentRegionScheduler).conf.LeaderLimit, Equals, uint64(32)) @@ -161,7 +161,7 @@ func (s *testBalanceAdjacentRegionSuite) TestNoNeedToBalance(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(AdjacentRegionType, nil)) + sc, err := schedule.CreateScheduler(AdjacentRegionType, schedule.NewOperatorController(s.ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(AdjacentRegionType, nil)) c.Assert(err, IsNil) c.Assert(sc.Schedule(tc), IsNil) @@ -199,7 +199,7 @@ func (s *testRejectLeaderSuite) TestRejectLeader(c *C) { // The label scheduler transfers leader out of store1. 
oc := schedule.NewOperatorController(ctx, nil, nil) - sl, err := schedule.CreateScheduler(LabelType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(LabelType, []string{"", ""})) + sl, err := schedule.CreateScheduler(LabelType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(LabelType, []string{"", ""})) c.Assert(err, IsNil) op := sl.Schedule(tc) testutil.CheckTransferLeaderFrom(c, op[0], operator.OpLeader, 1) @@ -211,13 +211,13 @@ func (s *testRejectLeaderSuite) TestRejectLeader(c *C) { // As store3 is disconnected, store1 rejects leader. Balancer will not create // any operators. - bs, err := schedule.CreateScheduler(BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + bs, err := schedule.CreateScheduler(BalanceLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) op = bs.Schedule(tc) c.Assert(op, IsNil) // Can't evict leader from store2, neither. 
- el, err := schedule.CreateScheduler(EvictLeaderType, oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"2"})) + el, err := schedule.CreateScheduler(EvictLeaderType, oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"2"})) c.Assert(err, IsNil) op = el.Schedule(tc) c.Assert(op, IsNil) @@ -248,7 +248,7 @@ func (s *testShuffleHotRegionSchedulerSuite) TestBalance(c *C) { tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "host"}) tc.DisableFeature(versioninfo.JointConsensus) - hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""})) + hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""})) c.Assert(err, IsNil) s.checkBalance(c, tc, opt, hb) @@ -307,7 +307,7 @@ func (s *testHotRegionSchedulerSuite) TestAbnormalReplica(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) tc.SetLeaderScheduleLimit(0) - hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), nil) + hb, err := schedule.CreateScheduler(HotReadRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), nil) c.Assert(err, IsNil) tc.AddRegionStore(1, 3) @@ -346,7 +346,7 @@ func (s *testEvictLeaderSuite) TestEvictLeader(c *C) { tc.AddLeaderRegion(2, 2, 1) tc.AddLeaderRegion(3, 3, 1) - sl, err := schedule.CreateScheduler(EvictLeaderType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"1"})) + sl, err := schedule.CreateScheduler(EvictLeaderType, 
schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(EvictLeaderType, []string{"1"})) c.Assert(err, IsNil) c.Assert(sl.IsScheduleAllowed(tc), IsTrue) op := sl.Schedule(tc) @@ -363,7 +363,7 @@ func (s *testShuffleRegionSuite) TestShuffle(c *C) { opt := config.NewTestOptions() tc := mockcluster.NewCluster(opt) - sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) c.Assert(err, IsNil) c.Assert(sl.IsScheduleAllowed(tc), IsTrue) c.Assert(sl.Schedule(tc), IsNil) @@ -427,7 +427,7 @@ func (s *testShuffleRegionSuite) TestRole(c *C) { }, peers[0]) tc.PutRegion(region) - sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) + sl, err := schedule.CreateScheduler(ShuffleRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(ShuffleRegionType, []string{"", ""})) c.Assert(err, IsNil) conf := sl.(*shuffleRegionScheduler).conf @@ -449,7 +449,7 @@ func (s *testSpecialUseSuite) TestSpecialUseHotRegion(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() oc := schedule.NewOperatorController(ctx, nil, nil) - storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) + storage := core.NewStorage(kv.NewMemoryKV()) cd := schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""}) bs, err := schedule.CreateScheduler(BalanceRegionType, oc, storage, cd) c.Assert(err, IsNil) @@ -502,7 +502,7 @@ func (s *testSpecialUseSuite) 
TestSpecialUseReserved(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() oc := schedule.NewOperatorController(ctx, nil, nil) - storage := core.NewStorage(kv.NewMemoryKV(), nil, nil) + storage := core.NewStorage(kv.NewMemoryKV()) cd := schedule.ConfigSliceDecoder(BalanceRegionType, []string{"", ""}) bs, err := schedule.CreateScheduler(BalanceRegionType, oc, storage, cd) c.Assert(err, IsNil) @@ -549,7 +549,7 @@ func (s *testBalanceLeaderSchedulerWithRuleEnabledSuite) SetUpTest(c *C) { s.tc = mockcluster.NewCluster(s.opt) s.tc.SetEnablePlacementRules(true) s.oc = schedule.NewOperatorController(s.ctx, nil, nil) - lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV(), nil, nil), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) + lb, err := schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) c.Assert(err, IsNil) s.lb = lb } diff --git a/server/server.go b/server/server.go index 2134eeefb29..83fe4073e72 100644 --- a/server/server.go +++ b/server/server.go @@ -371,7 +371,11 @@ func (s *Server) startServer(ctx context.Context) error { return err } - s.storage = core.NewStorage(kvBase, regionStorage, encryptionKeyManager) + s.storage = core.NewStorage( + kvBase, + core.WithRegionStorage(regionStorage), + core.WithEncryptionKeyManager(encryptionKeyManager), + ) s.basicCluster = core.NewBasicCluster() s.cluster = cluster.NewRaftCluster(ctx, s.GetClusterRootPath(), s.clusterID, syncer.NewRegionSyncer(s), s.client, s.httpClient) s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, s.cluster) diff --git a/server/statistics/region_collection_test.go b/server/statistics/region_collection_test.go index 0c3b9e26450..53beaf99ca2 100644 --- a/server/statistics/region_collection_test.go +++ b/server/statistics/region_collection_test.go @@ -31,7 +31,7 @@ type testRegionStatisticsSuite struct { } 
func (t *testRegionStatisticsSuite) SetUpTest(c *C) { - t.store = core.NewStorage(kv.NewMemoryKV(), nil, nil) + t.store = core.NewStorage(kv.NewMemoryKV()) var err error t.manager = placement.NewRuleManager(t.store) err = t.manager.Initialize(3, []string{"zone", "rack", "host"}) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 8abb6c1f727..39071900a54 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -460,7 +460,7 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) var stores []*metapb.Store id := leaderServer.GetAllocator() for _, addr := range storeAddrs { @@ -607,7 +607,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { // PUT GET failed oldStorage := svr.GetStorage() - svr.SetStorage(core.NewStorage(&testErrorKV{}, nil, nil)) + svr.SetStorage(core.NewStorage(&testErrorKV{})) replicationCfg.MaxReplicas = 7 scheduleCfg.MaxSnapshotCount = 20 pdServerCfg.UseRegionStorage = false @@ -626,7 +626,7 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { svr.SetStorage(oldStorage) c.Assert(svr.SetReplicationConfig(*replicationCfg), IsNil) - svr.SetStorage(core.NewStorage(&testErrorKV{}, nil, nil)) + svr.SetStorage(core.NewStorage(&testErrorKV{})) c.Assert(svr.DeleteLabelProperty(typ, labelKey, labelValue), NotNil) c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Key, Equals, "testKey") @@ -894,7 +894,7 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) id := leaderServer.GetAllocator() for 
_, addr := range storeAddrs { storeID, err := id.Alloc() @@ -981,7 +981,7 @@ func (s *clusterTestSuite) TestUpgradeStoreLimit(c *C) { bootstrapCluster(c, clusterID, grpcPDClient, "127.0.0.1:0") rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) store := newMetaStore(1, "127.0.1.1:0", "4.0.0", metapb.StoreState_Up, "test/store1") _, err = putStore(c, grpcPDClient, clusterID, store) c.Assert(err, IsNil) @@ -1039,7 +1039,7 @@ func (s *clusterTestSuite) TestStaleTermHeartbeat(c *C) { storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() c.Assert(rc, NotNil) - rc.SetStorage(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + rc.SetStorage(core.NewStorage(kv.NewMemoryKV())) var peers []*metapb.Peer id := leaderServer.GetAllocator() for _, addr := range storeAddrs { From c35f9487584aec7519d594f9f553fb71929f2a21 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 25 Sep 2020 06:43:37 +0800 Subject: [PATCH 16/37] move loadRegion and saveRegion Signed-off-by: Yi Wu --- server/core/region_storage.go | 38 ----------------------------------- server/core/storage.go | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 3fead555f68..9e5adfb4da2 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/encryption" @@ -131,43 +130,6 @@ func deleteRegion(kv kv.Base, region *metapb.Region) error { return kv.Remove(regionPath(region.GetId())) } -func saveRegion( - kv kv.Base, - encryptionKeyManager *encryptionkm.KeyManager, - region *metapb.Region, -) error { - err := encryption.EncryptRegion(region, encryptionKeyManager) - if err 
!= nil { - return err - } - value, err := proto.Marshal(region) - if err != nil { - return errs.ErrProtoMarshal.Wrap(err).GenWithStackByArgs() - } - return kv.Save(regionPath(region.GetId()), string(value)) -} - -func loadRegion( - kv kv.Base, - encryptionKeyManager *encryptionkm.KeyManager, - regionID uint64, - region *metapb.Region, -) (ok bool, err error) { - value, err := kv.Load(regionPath(regionID)) - if err != nil { - return false, err - } - if value == "" { - return false, nil - } - err = proto.Unmarshal([]byte(value), region) - if err != nil { - return true, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() - } - err = encryption.DecryptRegion(region, encryptionKeyManager) - return true, err -} - func loadRegions( kv kv.Base, encryptionKeyManager *encryptionkm.KeyManager, diff --git a/server/core/storage.go b/server/core/storage.go index d239bc88168..9582d5647ff 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -27,6 +27,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/server/encryptionkm" "github.com/tikv/pd/server/kv" @@ -577,3 +578,40 @@ func saveProto(s kv.Base, key string, msg proto.Message) error { } return s.Save(key, string(value)) } + +func loadRegion( + kv kv.Base, + encryptionKeyManager *encryptionkm.KeyManager, + regionID uint64, + region *metapb.Region, +) (ok bool, err error) { + value, err := kv.Load(regionPath(regionID)) + if err != nil { + return false, err + } + if value == "" { + return false, nil + } + err = proto.Unmarshal([]byte(value), region) + if err != nil { + return true, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() + } + err = encryption.DecryptRegion(region, encryptionKeyManager) + return true, err +} + +func saveRegion( + kv kv.Base, + encryptionKeyManager *encryptionkm.KeyManager, + region *metapb.Region, +) error { + err := 
encryption.EncryptRegion(region, encryptionKeyManager) + if err != nil { + return err + } + value, err := proto.Marshal(region) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStackByArgs() + } + return kv.Save(regionPath(region.GetId()), string(value)) +} From 551b3f9e41b2e1d8f5a6e42f1cf456ce7f337876 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 25 Sep 2020 07:03:17 +0800 Subject: [PATCH 17/37] revert changes Signed-off-by: Yi Wu --- pkg/mock/mockcluster/mockcluster.go | 2 +- server/schedulers/balance_test.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/mock/mockcluster/mockcluster.go b/pkg/mock/mockcluster/mockcluster.go index 54c3c786b46..937eeb3b333 100644 --- a/pkg/mock/mockcluster/mockcluster.go +++ b/pkg/mock/mockcluster/mockcluster.go @@ -135,7 +135,7 @@ func (mc *Cluster) AllocPeer(storeID uint64) (*metapb.Peer, error) { func (mc *Cluster) initRuleManager() { if mc.RuleManager == nil { - mc.RuleManager = placement.NewRuleManager(core.NewStorage(kv.NewMemoryKV(), nil, nil)) + mc.RuleManager = placement.NewRuleManager(core.NewStorage(kv.NewMemoryKV())) mc.RuleManager.Initialize(int(mc.GetReplicationConfig().MaxReplicas), mc.GetReplicationConfig().LocationLabels) } } diff --git a/server/schedulers/balance_test.go b/server/schedulers/balance_test.go index b7ec40d57e9..85a4423484d 100644 --- a/server/schedulers/balance_test.go +++ b/server/schedulers/balance_test.go @@ -509,7 +509,8 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestSingleRangeBalance(c *C) { ops := lb.Schedule(s.tc) c.Assert(ops, NotNil) c.Assert(ops, HasLen, 1) - c.Assert(ops[0].Counters, HasLen, 5) + c.Assert(ops[0].Counters, HasLen, 3) + c.Assert(ops[0].FinishedCounters, HasLen, 2) lb, err = schedule.CreateScheduler(BalanceLeaderType, s.oc, core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) c.Assert(err, IsNil) c.Assert(lb.Schedule(s.tc), IsNil) From 
7051407f4b468163be6b5a9fc9cff44f315e5f33 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 25 Sep 2020 09:28:33 +0800 Subject: [PATCH 18/37] fix doc Signed-off-by: Yi Wu --- server/core/storage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/core/storage.go b/server/core/storage.go index 9582d5647ff..5de8208ed38 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -78,7 +78,7 @@ func WithRegionStorage(regionStorage *RegionStorage) StorageOption { } } -// WithEncryptionManager sets EncryptionManager to the Storage +// WithEncryptionKeyManager sets EncryptionManager to the Storage func WithEncryptionKeyManager(encryptionKeyManager *encryptionkm.KeyManager) StorageOption { return func(opt *StorageOpt) { opt.encryptionKeyManager = encryptionKeyManager From 7201c37dcc95b1adcc5300e65437bb9c01a20f07 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Mon, 28 Sep 2020 11:41:43 +0800 Subject: [PATCH 19/37] address comment Signed-off-by: Yi Wu --- server/core/region_storage.go | 2 ++ server/core/storage.go | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 9e5adfb4da2..3cf688d47fe 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/encryption" @@ -104,6 +105,7 @@ func (s *RegionStorage) backgroundFlush() { // SaveRegion saves one region to storage. 
func (s *RegionStorage) SaveRegion(region *metapb.Region) error { + region = proto.Clone(region).(*metapb.Region) err := encryption.EncryptRegion(region, s.encryptionKeyManager) if err != nil { return err diff --git a/server/core/storage.go b/server/core/storage.go index 5de8208ed38..1f811994d89 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -289,7 +289,7 @@ func (s *Storage) LoadRuleGroups(f func(k, v string)) error { func (s *Storage) SaveJSON(prefix, key string, data interface{}) error { value, err := json.Marshal(data) if err != nil { - return err + return errs.ErrJSONMarshal.Wrap(err).GenWithStackByArgs() } return s.Save(path.Join(prefix, key), string(value)) } @@ -605,6 +605,7 @@ func saveRegion( encryptionKeyManager *encryptionkm.KeyManager, region *metapb.Region, ) error { + region = proto.Clone(region).(*metapb.Region) err := encryption.EncryptRegion(region, encryptionKeyManager) if err != nil { return err From ee2e0cd5661c3b642d53d46c532cc697dfb7684a Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 2 Oct 2020 12:41:40 +0800 Subject: [PATCH 20/37] key manager Signed-off-by: Yi Wu --- pkg/encryption/config.go | 7 +- pkg/encryption/master_key.go | 8 +- pkg/errs/errno.go | 28 +- server/core/storage.go | 12 +- server/encryptionkm/key_manager.go | 413 +++++++++++++++++++++++++++-- server/server.go | 4 +- 6 files changed, 425 insertions(+), 47 deletions(-) diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go index e690f8bfec5..1ec9a9a85f0 100644 --- a/pkg/encryption/config.go +++ b/pkg/encryption/config.go @@ -62,11 +62,14 @@ func (c *Config) Adjust() error { defaultDataKeyRotationPeriod) } c.DataKeyRotationPeriod.Duration = duration + } else if c.DataKeyRotationPeriod.Duration < 0 { + return errs.ErrEncryptionInvalidConfig.GenWithStack( + "data-key-rotation-period must be greater than 0") } if len(c.MasterKey.Type) == 0 { c.MasterKey.Type = masterKeyTypePlaintext } else { - if _, err := c.GetMasterKey(); err != nil { + if _, err 
:= c.GetMasterKeyMeta(); err != nil { return err } } @@ -91,7 +94,7 @@ func (c *Config) GetMethod() (encryptionpb.EncryptionMethod, error) { } // GetMasterKey gets the master key config. -func (c *Config) GetMasterKey() (*encryptionpb.MasterKey, error) { +func (c *Config) GetMasterKeyMeta() (*encryptionpb.MasterKey, error) { switch c.MasterKey.Type { case masterKeyTypePlaintext: return &encryptionpb.MasterKey{ diff --git a/pkg/encryption/master_key.go b/pkg/encryption/master_key.go index 6c5a7bb1597..952fa443bba 100644 --- a/pkg/encryption/master_key.go +++ b/pkg/encryption/master_key.go @@ -30,8 +30,6 @@ const ( // MasterKey is used to encrypt and decrypt encryption metadata (i.e. data encryption keys). type MasterKey struct { - // Master key config. Used to compare if two master key is the same. - Config *encryptionpb.MasterKey // Encryption key in plaintext. If it is nil, encryption is no-op. // Never output it to info log or persist it on disk. key []byte @@ -45,8 +43,7 @@ func NewMasterKey(config *encryptionpb.MasterKey) (*MasterKey, error) { } if plaintext := config.GetPlaintext(); plaintext != nil { return &MasterKey{ - Config: config, - key: nil, + key: nil, }, nil } if file := config.GetFile(); file != nil { @@ -55,8 +52,7 @@ func NewMasterKey(config *encryptionpb.MasterKey) (*MasterKey, error) { return nil, err } return &MasterKey{ - Config: config, - key: key, + key: key, }, nil } return nil, errors.New("unrecognized master key type") diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 84e82a0e53d..79e04d562a2 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -267,15 +267,21 @@ var ( // encryption var ( - ErrEncryptionInvalidMethod = errors.Normalize("invalid encryption method", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidMethod")) - ErrEncryptionInvalidConfig = errors.Normalize("invalid config", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidConfig")) - ErrEncryptionGenerateIV = errors.Normalize("fail to generate iv", 
errors.RFCCodeText("PD:encryption:ErrEncryptionGenerateIV")) - ErrEncryptionNewDataKey = errors.Normalize("fail to generate data key", errors.RFCCodeText("PD:encryption:ErrEncryptionNewDataKey")) - ErrEncryptionGCMEncrypt = errors.Normalize("GCM encryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionGCMEncrypt")) - ErrEncryptionGCMDecrypt = errors.Normalize("GCM decryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionGCMDecrypt")) - ErrEncryptionCTREncrypt = errors.Normalize("CTR encryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionCTREncrypt")) - ErrEncryptionCTRDecrypt = errors.Normalize("CTR decryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionCTRDecrypt")) - ErrEncryptionEncryptRegion = errors.Normalize("encrypt region fail", errors.RFCCodeText("PD:encryption:ErrEncryptionEncryptRegion")) - ErrEncryptionDecryptRegion = errors.Normalize("decrypt region fail", errors.RFCCodeText("PD:encryption:ErrEncryptionDecryptRegion")) - ErrEncryptionNewMasterKey = errors.Normalize("fail to get master key", errors.RFCCodeText("PD:encryption:ErrEncryptionNewMasterKey")) + ErrEncryptionInvalidMethod = errors.Normalize("invalid encryption method", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidMethod")) + ErrEncryptionInvalidConfig = errors.Normalize("invalid config", errors.RFCCodeText("PD:encryption:ErrEncryptionInvalidConfig")) + ErrEncryptionGenerateIV = errors.Normalize("fail to generate iv", errors.RFCCodeText("PD:encryption:ErrEncryptionGenerateIV")) + ErrEncryptionGCMEncrypt = errors.Normalize("GCM encryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionGCMEncrypt")) + ErrEncryptionGCMDecrypt = errors.Normalize("GCM decryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionGCMDecrypt")) + ErrEncryptionCTREncrypt = errors.Normalize("CTR encryption fail", errors.RFCCodeText("PD:encryption:ErrEncryptionCTREncrypt")) + ErrEncryptionCTRDecrypt = errors.Normalize("CTR decryption fail", 
errors.RFCCodeText("PD:encryption:ErrEncryptionCTRDecrypt")) + ErrEncryptionEncryptRegion = errors.Normalize("encrypt region fail", errors.RFCCodeText("PD:encryption:ErrEncryptionEncryptRegion")) + ErrEncryptionDecryptRegion = errors.Normalize("decrypt region fail", errors.RFCCodeText("PD:encryption:ErrEncryptionDecryptRegion")) + ErrEncryptionNewDataKey = errors.Normalize("fail to generate data key", errors.RFCCodeText("PD:encryption:ErrEncryptionNewDataKey")) + ErrEncryptionNewMasterKey = errors.Normalize("fail to get master key", errors.RFCCodeText("PD:encryption:ErrEncryptionNewMasterKey")) + ErrEncryptionCurrentKeyNotFound = errors.Normalize("current data key not found", errors.RFCCodeText("PD:encryption:ErrEncryptionCurrentKeyNotFound")) + ErrEncryptionKeyNotFound = errors.Normalize("data key not found", errors.RFCCodeText("PD:encryption:ErrEncryptionKeyNotFound")) + ErrEncryptionKeysWatcher = errors.Normalize("data key watcher error", errors.RFCCodeText("PD:encryption:ErrEncryptionKeysWatcher")) + ErrEncryptionLoadKeys = errors.Normalize("load data keys error", errors.RFCCodeText("PD:encryption:ErrEncryptionLoadKeys")) + ErrEncryptionRotateDataKey = errors.Normalize("failed to rotate data key", errors.RFCCodeText("PD:encryption:ErrEncryptionRotateDataKey")) + ErrEncryptionSaveDataKeys = errors.Normalize("failed to save data keys", errors.RFCCodeText("PD:encryption:ErrEncryptionSaveDataKeys")) ) diff --git a/server/core/storage.go b/server/core/storage.go index 1f811994d89..1dcdfcb0d88 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -44,7 +44,9 @@ const ( replicationPath = "replication_mode" componentPath = "component" customScheduleConfigPath = "scheduler_config" - encryptionKeysPath = "encryption_keys" + + // Reserved to encryption + encryptionKeysPath = "encryption" ) const ( @@ -134,11 +136,6 @@ func (s *Storage) storeRegionWeightPath(storeID uint64) string { return path.Join(schedulePath, "store_weight", fmt.Sprintf("%020d", storeID), 
"region") } -// EncryptionKeysPath returns the path to save encryption keys. -func (s *Storage) EncryptionKeysPath() string { - return path.Join(encryptionKeysPath, "keys") -} - // SaveScheduleConfig saves the config of scheduler. func (s *Storage) SaveScheduleConfig(scheduleName string, data []byte) error { configPath := path.Join(customScheduleConfigPath, scheduleName) @@ -438,9 +435,6 @@ func (s *Storage) Close() error { return err } } - if s.encryptionKeyManager != nil { - s.encryptionKeyManager.Close() - } return nil } diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 945f5348118..bd7483185f6 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -14,41 +14,420 @@ package encryptionkm import ( + "bytes" + "context" + "sync" + "time" + + "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/encryptionpb" - lib "github.com/tikv/pd/pkg/encryption" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/server/election" "github.com/tikv/pd/server/kv" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/mvcc/mvccpb" + "go.uber.org/zap" +) + +const ( + // Special key id to denote encryption is currently not enabled. + disableEncryptionKeyID = 0 + // Check interval for data key rotation. + keyRotationCheckPeriod = time.Minute * 10 + // Times to retry generating new data key. + keyRotationRetryLimit = 10 ) // KeyManager maintains the list to encryption keys. It handles encryption key generation and // rotation, persisting and loading encryption keys. -type KeyManager struct{} +type KeyManager struct { + // Backing storage for key dictionary. + etcdClient *clientv3.Client + // Encryption method used to encrypt data + method encryptionpb.EncryptionMethod + // Time interval between data key rotation. + dataKeyRotationPeriod time.Duration + // Metadata defines the master key to use. 
+ masterKeyMeta *encryptionpb.MasterKey + // Encryption config from config file. Only used when current node is PD leader. + config *encryption.Config + // Mutex for updating keys. + muUpdate sync.Mutex + // PD leadership of the current PD node. Only the PD leader will rotate data keys, + // or change current encryption method. Guarded by muUpdate. + leadership *election.Leadership + // ModRevision of the encryption keys data stored in etcd. + // Used to do CAS update to encryption keys. + keysRevision int64 + // Mutex for accessing keys. + mu sync.Mutex + // List of all encryption keys and current encryption key id. Guarded by mu. + keys *encryptionpb.KeyDictionary + // Error hit when loading encryption keys. + keysError error +} + +// EncryptionKeysPath return the path to store key dictionary in etcd. +func EncryptionKeysPath() string { + return "encryption/keys" +} // NewKeyManager creates a new key manager. -func NewKeyManager(kv kv.Base, config *lib.Config) (*KeyManager, error) { - // TODO: Implement - return &KeyManager{}, nil +func NewKeyManager( + ctx context.Context, + etcdClient *clientv3.Client, + config *encryption.Config, +) (*KeyManager, error) { + method, err := config.GetMethod() + if err != nil { + return nil, err + } + masterKeyMeta, err := config.GetMasterKeyMeta() + if err != nil { + return nil, err + } + m := &KeyManager{ + etcdClient: etcdClient, + method: method, + dataKeyRotationPeriod: config.DataKeyRotationPeriod.Duration, + masterKeyMeta: masterKeyMeta, + } + // Load encryption keys from storage. + _, err = m.loadKeys() + if err != nil { + return nil, err + } + // Start periodic check for keys change and rotation key if needed. + go m.startBackgroundLoop(ctx, m.keysRevision) + return m, nil +} + +func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { + // Create new context for the loop. 
+	loopCtx, loopCancel := context.WithCancel(ctx)
+	defer loopCancel()
+ case <-loopCtx.Done(): + log.Info("encryption key manager is closed") + return + } + } +} + +func (m *KeyManager) setKeysError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.keysError = err +} + +func (m *KeyManager) loadKeysFromKV( + kv *mvccpb.KeyValue, +) (*encryptionpb.KeyDictionary, error) { + content := &encryptionpb.EncryptedContent{} + err := content.Unmarshal(kv.Value) + if err != nil { + return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( + "fail to unmarshal encrypted encryption keys") + } + masterKeyConfig := content.MasterKey + if masterKeyConfig == nil { + return nil, errs.ErrEncryptionLoadKeys.GenWithStack( + "no master key config found with encryption keys") + } + masterKey, err := encryption.NewMasterKey(masterKeyConfig) + if err != nil { + return nil, err + } + plaintextContent, err := masterKey.Decrypt(content.Content, content.Iv) + if err != nil { + return nil, err + } + keys := &encryptionpb.KeyDictionary{} + err = keys.Unmarshal(plaintextContent) + if err != nil { + return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( + "fail to unmarshal encryption keys") + } + { + m.mu.Lock() + m.keys = keys + m.mu.Unlock() + } + m.keysRevision = kv.ModRevision + return keys, nil +} + +func (m *KeyManager) loadKeys() (*encryptionpb.KeyDictionary, error) { + resp, err := etcdutil.EtcdKVGet(m.etcdClient, EncryptionKeysPath()) + if err != nil { + return nil, err + } + if resp == nil || len(resp.Kvs) == 0 { + return nil, nil + } + return m.loadKeysFromKV(resp.Kvs[0]) +} + +func (m *KeyManager) saveKeys(keys *encryptionpb.KeyDictionary) error { + // Get master key. + masterKeyMeta, err := m.config.GetMasterKeyMeta() + if err != nil { + return err + } + masterKey, err := encryption.NewMasterKey(masterKeyMeta) + if err != nil { + return err + } + // Set was_exposed flag if master key is plaintext (no-op). + if masterKey.IsPlaintext() { + for _, key := range keys.Keys { + key.WasExposed = true + } + } + // Encode and encrypt data keys. 
+ plaintextContent, err := proto.Marshal(keys) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypion keys") + } + ciphertextContent, iv, err := masterKey.Encrypt(plaintextContent) + if err != nil { + return err + } + content := &encryptionpb.EncryptedContent{ + Content: ciphertextContent, + MasterKey: masterKeyMeta, + Iv: iv, + } + value, err := proto.Marshal(content) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypted encryption keys") + } + resp, err := kv.NewSlowLogTxn(m.etcdClient). + If(clientv3.Compare(clientv3.ModRevision(EncryptionKeysPath()), "=", m.keysRevision)). + Then(clientv3.OpPut(EncryptionKeysPath(), string(value))). + Commit() + if err != nil { + return errs.ErrEtcdTxn.Wrap(err).GenWithStack("fail to save encryption keys") + } + if !resp.Succeeded { + return errs.ErrEncryptionSaveDataKeys.GenWithStack( + "write conflict, expected revision %d", m.keysRevision) + } + // Leave for the watcher to load the updated keys. + return nil +} + +func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { + if m.leadership == nil || !m.leadership.Check() { + // We are not leader. + m.leadership = nil + return nil + } + var keys *encryptionpb.KeyDictionary + // Make a clone of encryption keys. + { + m.mu.Lock() + keys = proto.Clone(m.keys).(*encryptionpb.KeyDictionary) + m.mu.Unlock() + } + // Initialize if empty. + if keys == nil { + keys = &encryptionpb.KeyDictionary{ + CurrentKeyId: disableEncryptionKeyID, + } + } + if keys.Keys == nil { + keys.Keys = make(map[uint64]*encryptionpb.DataKey) + } + method, err := m.config.GetMethod() + if err != nil { + return err + } + needUpdate := forceUpdate + if method == encryptionpb.EncryptionMethod_PLAINTEXT { + if keys.CurrentKeyId == disableEncryptionKeyID { + // Encryption is not enabled. 
+ return nil + } + keys.CurrentKeyId = disableEncryptionKeyID + needUpdate = true + } else { + needRotate := false + if keys.CurrentKeyId == disableEncryptionKeyID { + needRotate = true + } else { + currentKey := keys.Keys[keys.CurrentKeyId] + if currentKey == nil { + return errs.ErrEncryptionCurrentKeyNotFound.GenWithStack("keyId = %d", keys.CurrentKeyId) + } + // Rotate key in case of: + // * Encryption method is changed. + // * Currnet key is exposed. + // * Current key expired. + if currentKey.Method != method || currentKey.WasExposed || + time.Unix(int64(currentKey.CreationTime), 0). + Add(m.config.DataKeyRotationPeriod.Duration).Before(time.Now()) { + needRotate = true + } + } + if needRotate { + rotated := false + for attempt := 0; attempt < keyRotationRetryLimit; attempt += 1 { + keyID, key, err := encryption.NewDataKey(method) + if err != nil { + return nil + } + if keys.Keys[keyID] == nil { + keys.Keys[keyID] = key + keys.CurrentKeyId = keyID + rotated = true + break + } + // Duplicated key id. retry. + } + if !rotated { + return errs.ErrEncryptionRotateDataKey.GenWithStack("maximum attempts reached") + } + needUpdate = true + } + } + if !needUpdate { + return nil + } + return m.saveKeys(keys) } // GetCurrentKey get the current encryption key. The key is nil if encryption is not enabled. func (m *KeyManager) GetCurrentKey() (keyID uint64, key *encryptionpb.DataKey, err error) { - // TODO: Implement - return 0, nil, nil + if m.method == encryptionpb.EncryptionMethod_PLAINTEXT { + // Encryption is not enabled. + return + } + m.mu.Lock() + defer m.mu.Unlock() + if m.keysError != nil { + return 0, nil, m.keysError + } + if m.keys.CurrentKeyId == disableEncryptionKeyID { + // Encryption is not enabled. 
+ return 0, nil, nil + } + keyID = m.keys.CurrentKeyId + if m.keys == nil { + return 0, nil, errs.ErrEncryptionCurrentKeyNotFound.GenWithStack( + "empty key list, currentKeyID = %d", keyID) + } + key = m.keys.Keys[keyID] + if key == nil { + // Shouldn't happen, unless key dictionary is corrupted. + return 0, nil, errs.ErrEncryptionCurrentKeyNotFound.GenWithStack("currentKeyID = %d", keyID) + } + return } // GetKey get the encryption key with the specific key id. -func (m *KeyManager) GetKey(keyID uint64) (key *encryptionpb.DataKey, err error) { - // TODO: Implement - return nil, nil +func (m *KeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { + localGetKey := func(keyId uint64) (*encryptionpb.DataKey, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.keysError != nil { + return nil, m.keysError + } + return m.keys.Keys[keyId], nil + } + key, err := localGetKey(keyID) + if err != nil { + return nil, err + } + if key != nil { + return key, nil + } + // Key not found in memory. + // The key could be generated by another PD node, which shouldn't happen normally. + m.muUpdate.Lock() + defer m.muUpdate.Unlock() + // Double check, in case keys is updated by watcher or another GetKey call. + key, err = localGetKey(keyID) + if err != nil { + return nil, err + } + if key != nil { + return key, nil + } + // Reload keys from storage. + keys, err := m.loadKeys() + if err != nil { + return nil, err + } + if keys == nil { + key = nil + } else { + key = keys.Keys[keyID] + } + if key == nil { + return nil, errs.ErrEncryptionKeyNotFound.GenWithStack("keyId = %d", keyID) + } + return key, nil } // SetLeadership sets the PD leadership of the current node. PD leader is responsible to update // encryption keys, e.g. key rotation. 
-func (m *KeyManager) SetLeadership(leadership *election.Leadership) { - // TODO: Implement -} - -// Close close the key manager on PD server shutdown -func (m *KeyManager) Close() { - // TODO: Implement +func (m *KeyManager) SetLeadership(leadership *election.Leadership) error { + m.muUpdate.Lock() + defer m.muUpdate.Unlock() + m.leadership = leadership + // Reload keys just in case we are not up-to-date. + _, err := m.loadKeys() + if err != nil { + return err + } + return m.rotateKeyIfNeeded(true /*forceUpdate*/) } diff --git a/server/server.go b/server/server.go index f09e7d447a5..c117596e316 100644 --- a/server/server.go +++ b/server/server.go @@ -360,12 +360,12 @@ func (s *Server) startServer(ctx context.Context) error { if err = s.tsoAllocatorManager.SetLocalTSOConfig(s.cfg.LocalTSO); err != nil { return err } - kvBase := kv.NewEtcdKVBase(s.client, s.rootPath) - encryptionKeyManager, err := encryptionkm.NewKeyManager(kvBase, &s.cfg.Security.Encryption) + encryptionKeyManager, err := encryptionkm.NewKeyManager(ctx, s.client, &s.cfg.Security.Encryption) if err != nil { return err } s.encryptionKeyManager = encryptionKeyManager + kvBase := kv.NewEtcdKVBase(s.client, s.rootPath) path := filepath.Join(s.cfg.DataDir, "region-meta") regionStorage, err := core.NewRegionStorage(ctx, path, encryptionKeyManager) if err != nil { From 9e753b44146cc37bfa08a7be1c2c64fa3798cc36 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 02:56:52 +0800 Subject: [PATCH 21/37] add test and refactor Signed-off-by: Yi Wu --- pkg/encryption/crypter.go | 4 +- pkg/encryption/crypter_test.go | 5 +- pkg/encryption/master_key_test.go | 3 - server/encryptionkm/key_manager.go | 316 ++++---- server/encryptionkm/key_manager_test.go | 945 ++++++++++++++++++++++++ 5 files changed, 1102 insertions(+), 171 deletions(-) create mode 100644 server/encryptionkm/key_manager_test.go diff --git a/pkg/encryption/crypter.go b/pkg/encryption/crypter.go index d35b75dbff9..54f65539b10 100644 --- 
a/pkg/encryption/crypter.go +++ b/pkg/encryption/crypter.go @@ -19,7 +19,6 @@ import ( "crypto/rand" "encoding/binary" "io" - "time" "unsafe" "github.com/pingcap/kvproto/pkg/encryptionpb" @@ -100,6 +99,7 @@ func NewIvGCM() (IvGCM, error) { // NewDataKey randomly generate a new data key. func NewDataKey( method encryptionpb.EncryptionMethod, + creationTime uint64, ) (keyID uint64, key *encryptionpb.DataKey, err error) { err = CheckEncryptionMethodSupported(method) if err != nil { @@ -138,7 +138,7 @@ func NewDataKey( key = &encryptionpb.DataKey{ Key: keyBuf, Method: method, - CreationTime: uint64(time.Now().Unix()), + CreationTime: creationTime, WasExposed: false, } return diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index d140117a435..5b5d3bbd987 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -64,13 +64,14 @@ func (s *testCrypterSuite) TestNewIv(c *C) { } func testNewDataKey(c *C, method encryptionpb.EncryptionMethod) { - _, key, err := NewDataKey(method) + _, key, err := NewDataKey(method, uint64(123)) c.Assert(err, IsNil) length, err := KeyLength(method) c.Assert(err, IsNil) - c.Assert(len(key.Key), Equals, length) + c.Assert(key.Key, HasLen, length) c.Assert(key.Method, Equals, method) c.Assert(key.WasExposed, IsFalse) + c.Assert(key.CreationTime, Equals, uint64(123)) } func (s *testCrypterSuite) TestNewDataKey(c *C) { diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 17619fe2b72..6680d7e3e9f 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -18,7 +18,6 @@ import ( "io/ioutil" "testing" - "github.com/gogo/protobuf/proto" . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/encryptionpb" ) @@ -40,7 +39,6 @@ func (s *testMasterKeySuite) TestPlaintextMasterKey(c *C) { masterKey, err := NewMasterKey(config) c.Assert(err, IsNil) c.Assert(masterKey, Not(IsNil)) - c.Assert(proto.Equal(config, masterKey.Config), IsTrue) c.Assert(len(masterKey.key), Equals, 0) plaintext := "this is a plaintext" @@ -159,6 +157,5 @@ func (s *testMasterKeySuite) TestNewFileMasterKey(c *C) { } masterKey, err := NewMasterKey(config) c.Assert(err, IsNil) - c.Assert(proto.Equal(masterKey.Config, config), IsTrue) c.Assert(hex.EncodeToString(masterKey.key), Equals, key) } diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index bd7483185f6..73e84674975 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -14,9 +14,9 @@ package encryptionkm import ( - "bytes" "context" "sync" + "sync/atomic" "time" "github.com/gogo/protobuf/proto" @@ -26,7 +26,6 @@ import ( "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/server/election" - "github.com/tikv/pd/server/kv" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/mvcc/mvccpb" "go.uber.org/zap" @@ -41,6 +40,14 @@ const ( keyRotationRetryLimit = 10 ) +// Test helpers +var ( + now = func() time.Time { return time.Now() } + tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } + eventAfterReloadByWatcher = func() {} + eventAfterTicker = func() {} +) + // KeyManager maintains the list to encryption keys. It handles encryption key generation and // rotation, persisting and loading encryption keys. type KeyManager struct { @@ -52,22 +59,14 @@ type KeyManager struct { dataKeyRotationPeriod time.Duration // Metadata defines the master key to use. masterKeyMeta *encryptionpb.MasterKey - // Encryption config from config file. Only used when current node is PD leader. - config *encryption.Config - // Mutex for updating keys. + // Mutex for updating keys. 
Used for both LoadKeys() and rotateKeyIfNeeded().
+ plaintextContent, err := proto.Marshal(keys) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypion keys") + } + ciphertextContent, iv, err := masterKey.Encrypt(plaintextContent) + if err != nil { + return err + } + content := &encryptionpb.EncryptedContent{ + Content: ciphertextContent, + MasterKey: masterKeyMeta, + Iv: iv, + } + value, err := proto.Marshal(content) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypted encryption keys") + } + // Avoid write conflict with PD peer by checking if we are leader. + resp, err := leadership.LeaderTxn(). + Then(clientv3.OpPut(EncryptionKeysPath(), string(value))). + Commit() + if err != nil { + return errs.ErrEtcdTxn.Wrap(err).GenWithStack("fail to save encryption keys") + } + if !resp.Succeeded { + return errs.ErrEncryptionSaveDataKeys.GenWithStack("leader expired") + } + // Leave for the watcher to load the updated keys. + return nil +} + +func loadKeysFromKV(kv *mvccpb.KeyValue) (*encryptionpb.KeyDictionary, error) { + content := &encryptionpb.EncryptedContent{} + err := content.Unmarshal(kv.Value) + if err != nil { + return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( + "fail to unmarshal encrypted encryption keys") + } + masterKeyConfig := content.MasterKey + if masterKeyConfig == nil { + return nil, errs.ErrEncryptionLoadKeys.GenWithStack( + "no master key config found with encryption keys") + } + masterKey, err := encryption.NewMasterKey(masterKeyConfig) + if err != nil { + return nil, err + } + plaintextContent, err := masterKey.Decrypt(content.Content, content.Iv) + if err != nil { + return nil, err + } + keys := &encryptionpb.KeyDictionary{} + err = keys.Unmarshal(plaintextContent) + if err != nil { + return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( + "fail to unmarshal encryption keys") + } + return keys, nil +} + // NewKeyManager creates a new key manager. 
func NewKeyManager( ctx context.Context, @@ -96,12 +174,12 @@ func NewKeyManager( masterKeyMeta: masterKeyMeta, } // Load encryption keys from storage. - _, err = m.loadKeys() + _, revision, err := m.loadKeys() if err != nil { return nil, err } // Start periodic check for keys change and rotation key if needed. - go m.startBackgroundLoop(ctx, m.keysRevision) + go m.startBackgroundLoop(ctx, revision) return m, nil } @@ -113,9 +191,9 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { defer watcher.Close() watcherCtx, cancel := context.WithCancel(ctx) defer cancel() - watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath(), clientv3.WithRev(m.keysRevision)) - // Check data key rotation every min(DataKeyRotationPeriod, keyRotationCheckPeriod). - checkPeriod := m.config.DataKeyRotationPeriod.Duration + watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath(), clientv3.WithRev(revision)) + // Check data key rotation every min(dataKeyRotationPeriod, keyRotationCheckPeriod). + checkPeriod := m.dataKeyRotationPeriod if keyRotationCheckPeriod < checkPeriod { checkPeriod = keyRotationCheckPeriod } @@ -127,18 +205,14 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { // Reload encryption keys updated by PD leader (could be ourselves). case resp := <-watchChan: if resp.Canceled { + // If the watcher failed, we rely solely on rotateKeyIfNeeded to reload encryption keys. 
log.Warn("encryption key watcher canceled") - m.setKeysError(errs.ErrEncryptionKeysWatcher.GenWithStack("watcher is canceled")) - return + continue } for _, event := range resp.Events { if event.Type != mvccpb.PUT { - m.setKeysError(errs.ErrEncryptionKeysWatcher.GenWithStack("encryption keys deleted")) - return - } - if !bytes.Equal([]byte(EncryptionKeysPath()), event.Kv.Key) { - m.setKeysError(errs.ErrEncryptionKeysWatcher.GenWithStack("encryption keys path not equal")) - return + log.Warn("encryption keys is deleted unexpectedly") + continue } { m.muUpdate.Lock() @@ -146,14 +220,16 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { m.muUpdate.Unlock() } } + eventAfterReloadByWatcher() // Check data key rotation in case we are the PD leader. - case <-ticker.C: + case <-tick(ticker): m.muUpdate.Lock() err := m.rotateKeyIfNeeded(false /*forceUpdate*/) - m.muUpdate.Unlock() if err != nil { - log.Warn("fail to rotate encryption master key", zap.Error(err)) + log.Warn("fail to rotate data encryption key", zap.Error(err)) } + m.muUpdate.Unlock() + eventAfterTicker() // Server shutdown. 
case <-loopCtx.Done(): log.Info("encryption key manager is closed") @@ -162,107 +238,31 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { } } -func (m *KeyManager) setKeysError(err error) { - m.mu.Lock() - defer m.mu.Unlock() - m.keysError = err -} - func (m *KeyManager) loadKeysFromKV( kv *mvccpb.KeyValue, ) (*encryptionpb.KeyDictionary, error) { - content := &encryptionpb.EncryptedContent{} - err := content.Unmarshal(kv.Value) - if err != nil { - return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( - "fail to unmarshal encrypted encryption keys") - } - masterKeyConfig := content.MasterKey - if masterKeyConfig == nil { - return nil, errs.ErrEncryptionLoadKeys.GenWithStack( - "no master key config found with encryption keys") - } - masterKey, err := encryption.NewMasterKey(masterKeyConfig) + keys, err := loadKeysFromKV(kv) if err != nil { return nil, err } - plaintextContent, err := masterKey.Decrypt(content.Content, content.Iv) - if err != nil { - return nil, err - } - keys := &encryptionpb.KeyDictionary{} - err = keys.Unmarshal(plaintextContent) - if err != nil { - return nil, errs.ErrProtoUnmarshal.Wrap(err).GenWithStack( - "fail to unmarshal encryption keys") - } - { - m.mu.Lock() - m.keys = keys - m.mu.Unlock() - } - m.keysRevision = kv.ModRevision + m.keys.Store(keys) + log.Info("reloaded encryption keys", zap.Int64("revision", kv.ModRevision)) return keys, nil } -func (m *KeyManager) loadKeys() (*encryptionpb.KeyDictionary, error) { +func (m *KeyManager) loadKeys() (keys *encryptionpb.KeyDictionary, revision int64, err error) { resp, err := etcdutil.EtcdKVGet(m.etcdClient, EncryptionKeysPath()) if err != nil { - return nil, err + return nil, 0, err } if resp == nil || len(resp.Kvs) == 0 { - return nil, nil + return nil, 0, nil } - return m.loadKeysFromKV(resp.Kvs[0]) -} - -func (m *KeyManager) saveKeys(keys *encryptionpb.KeyDictionary) error { - // Get master key. 
- masterKeyMeta, err := m.config.GetMasterKeyMeta() - if err != nil { - return err - } - masterKey, err := encryption.NewMasterKey(masterKeyMeta) + keys, err = m.loadKeysFromKV(resp.Kvs[0]) if err != nil { - return err - } - // Set was_exposed flag if master key is plaintext (no-op). - if masterKey.IsPlaintext() { - for _, key := range keys.Keys { - key.WasExposed = true - } - } - // Encode and encrypt data keys. - plaintextContent, err := proto.Marshal(keys) - if err != nil { - return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypion keys") - } - ciphertextContent, iv, err := masterKey.Encrypt(plaintextContent) - if err != nil { - return err - } - content := &encryptionpb.EncryptedContent{ - Content: ciphertextContent, - MasterKey: masterKeyMeta, - Iv: iv, + return nil, 0, err } - value, err := proto.Marshal(content) - if err != nil { - return errs.ErrProtoMarshal.Wrap(err).GenWithStack("fail to marshal encrypted encryption keys") - } - resp, err := kv.NewSlowLogTxn(m.etcdClient). - If(clientv3.Compare(clientv3.ModRevision(EncryptionKeysPath()), "=", m.keysRevision)). - Then(clientv3.OpPut(EncryptionKeysPath(), string(value))). - Commit() - if err != nil { - return errs.ErrEtcdTxn.Wrap(err).GenWithStack("fail to save encryption keys") - } - if !resp.Succeeded { - return errs.ErrEncryptionSaveDataKeys.GenWithStack( - "write conflict, expected revision %d", m.keysRevision) - } - // Leave for the watcher to load the updated keys. - return nil + return keys, resp.Kvs[0].ModRevision, err } func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { @@ -271,12 +271,10 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { m.leadership = nil return nil } - var keys *encryptionpb.KeyDictionary - // Make a clone of encryption keys. - { - m.mu.Lock() - keys = proto.Clone(m.keys).(*encryptionpb.KeyDictionary) - m.mu.Unlock() + // Reload encryption keys in case we are not up-to-date. 
+ keys, _, err := m.loadKeys() + if err != nil { + return err } // Initialize if empty. if keys == nil { @@ -287,12 +285,8 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { if keys.Keys == nil { keys.Keys = make(map[uint64]*encryptionpb.DataKey) } - method, err := m.config.GetMethod() - if err != nil { - return err - } needUpdate := forceUpdate - if method == encryptionpb.EncryptionMethod_PLAINTEXT { + if m.method == encryptionpb.EncryptionMethod_PLAINTEXT { if keys.CurrentKeyId == disableEncryptionKeyID { // Encryption is not enabled. return nil @@ -312,16 +306,16 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { // * Encryption method is changed. // * Currnet key is exposed. // * Current key expired. - if currentKey.Method != method || currentKey.WasExposed || + if currentKey.Method != m.method || currentKey.WasExposed || time.Unix(int64(currentKey.CreationTime), 0). - Add(m.config.DataKeyRotationPeriod.Duration).Before(time.Now()) { + Add(m.dataKeyRotationPeriod).Before(now()) { needRotate = true } } if needRotate { rotated := false for attempt := 0; attempt < keyRotationRetryLimit; attempt += 1 { - keyID, key, err := encryption.NewDataKey(method) + keyID, key, err := encryption.NewDataKey(m.method, uint64(now().Unix())) if err != nil { return nil } @@ -342,51 +336,53 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { if !needUpdate { return nil } - return m.saveKeys(keys) + err = saveKeys(m.etcdClient, m.leadership, m.masterKeyMeta, keys) + if err != nil { + return err + } + // Update local keys. + m.keys.Store(keys) + return err +} + +func (m *KeyManager) getKeys() *encryptionpb.KeyDictionary { + keys := m.keys.Load() + if keys == nil { + return nil + } + return keys.(*encryptionpb.KeyDictionary) } // GetCurrentKey get the current encryption key. The key is nil if encryption is not enabled. 
func (m *KeyManager) GetCurrentKey() (keyID uint64, key *encryptionpb.DataKey, err error) { - if m.method == encryptionpb.EncryptionMethod_PLAINTEXT { - // Encryption is not enabled. - return - } - m.mu.Lock() - defer m.mu.Unlock() - if m.keysError != nil { - return 0, nil, m.keysError - } - if m.keys.CurrentKeyId == disableEncryptionKeyID { + keys := m.getKeys() + if keys == nil || keys.CurrentKeyId == disableEncryptionKeyID { // Encryption is not enabled. return 0, nil, nil } - keyID = m.keys.CurrentKeyId - if m.keys == nil { + keyID = keys.CurrentKeyId + if keys.Keys == nil { return 0, nil, errs.ErrEncryptionCurrentKeyNotFound.GenWithStack( "empty key list, currentKeyID = %d", keyID) } - key = m.keys.Keys[keyID] + key = keys.Keys[keyID] if key == nil { // Shouldn't happen, unless key dictionary is corrupted. return 0, nil, errs.ErrEncryptionCurrentKeyNotFound.GenWithStack("currentKeyID = %d", keyID) } - return + return keyID, key, nil } // GetKey get the encryption key with the specific key id. func (m *KeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { - localGetKey := func(keyId uint64) (*encryptionpb.DataKey, error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.keysError != nil { - return nil, m.keysError + localGetKey := func(keyId uint64) *encryptionpb.DataKey { + keys := m.getKeys() + if keys == nil || keys.Keys == nil { + return nil } - return m.keys.Keys[keyId], nil - } - key, err := localGetKey(keyID) - if err != nil { - return nil, err + return keys.Keys[keyId] } + key := localGetKey(keyID) if key != nil { return key, nil } @@ -395,15 +391,12 @@ func (m *KeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { m.muUpdate.Lock() defer m.muUpdate.Unlock() // Double check, in case keys is updated by watcher or another GetKey call. - key, err = localGetKey(keyID) - if err != nil { - return nil, err - } + key = localGetKey(keyID) if key != nil { return key, nil } // Reload keys from storage. 
- keys, err := m.loadKeys() + keys, _, err := m.loadKeys() if err != nil { return nil, err } @@ -424,10 +417,5 @@ func (m *KeyManager) SetLeadership(leadership *election.Leadership) error { m.muUpdate.Lock() defer m.muUpdate.Unlock() m.leadership = leadership - // Reload keys just in case we are not up-to-date. - _, err := m.loadKeys() - if err != nil { - return err - } return m.rotateKeyIfNeeded(true /*forceUpdate*/) } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go new file mode 100644 index 00000000000..0ab1f5831d3 --- /dev/null +++ b/server/encryptionkm/key_manager_test.go @@ -0,0 +1,945 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encryptionkm + +import ( + "context" + "encoding/hex" + "fmt" + "io/ioutil" + "net/url" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + . 
"github.com/pingcap/check" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/tempurl" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/election" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/embed" +) + +func TestKeyManager(t *testing.T) { + TestingT(t) +} + +type testKeyManagerSuite struct{} + +var _ = Suite(&testKeyManagerSuite{}) + +const ( + testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" + testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" + testDataKey = "be798242dde0c40d9a65cdbc36c1c9ac" +) + +func getTestDataKey() []byte { + key, _ := hex.DecodeString(testDataKey) + return key +} + +func newTestEtcd(c *C) (client *clientv3.Client, cleanup func()) { + cfg := embed.NewConfig() + cfg.Name = "test_etcd" + cfg.Dir, _ = ioutil.TempDir("/tmp", "test_etcd") + cfg.Logger = "zap" + pu, err := url.Parse(tempurl.Alloc()) + c.Assert(err, IsNil) + cfg.LPUrls = []url.URL{*pu} + cfg.APUrls = cfg.LPUrls + cu, err := url.Parse(tempurl.Alloc()) + c.Assert(err, IsNil) + cfg.LCUrls = []url.URL{*cu} + cfg.ACUrls = cfg.LCUrls + cfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, &cfg.LPUrls[0]) + cfg.ClusterState = embed.ClusterStateFlagNew + server, err := embed.StartEtcd(cfg) + c.Assert(err, IsNil) + <-server.Server.ReadyNotify() + + client, err = clientv3.New(clientv3.Config{ + Endpoints: []string{cfg.LCUrls[0].String()}, + }) + c.Assert(err, IsNil) + + cleanup = func() { + client.Close() + server.Close() + os.RemoveAll(cfg.Dir) + } + + return client, cleanup +} + +func newTestKeyFile(c *C, key ...string) (keyFilePath string, cleanup func()) { + testKey := testMasterKey + for _, k := range key { + testKey = k + } + tempDir, err := ioutil.TempDir("/tmp", "test_key_file") + c.Assert(err, IsNil) + keyFilePath = tempDir + "/key" + err = ioutil.WriteFile(keyFilePath, []byte(testKey), os.ModeAppend) + 
c.Assert(err, IsNil) + + cleanup = func() { + os.RemoveAll(tempDir) + } + + return keyFilePath, cleanup +} + +func newTestLeader(c *C, client *clientv3.Client) *election.Leadership { + leader := election.NewLeadership(client, "test_leader", "test") + timeout := int64(30000000) // about a year. + err := leader.Campaign(timeout, "") + c.Assert(err, IsNil) + return leader +} + +func checkMasterKeyMeta(c *C, value []byte, meta *encryptionpb.MasterKey) { + content := &encryptionpb.EncryptedContent{} + err := content.Unmarshal(value) + c.Assert(err, IsNil) + c.Assert(proto.Equal(content.MasterKey, meta), IsTrue) +} + +func (s *testKeyManagerSuite) TestNewKeyManagerBasic(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + // Cancel background loop. + cancel() + // Use default config. + config := &encryption.Config{} + err := config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + // Check config. + c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_PLAINTEXT) + c.Assert(m.masterKeyMeta.GetPlaintext(), NotNil) + // Check loaded keys. + c.Assert(m.keys.Load(), IsNil) + // Check etcd KV. + value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + c.Assert(value, IsNil) +} + +func (s *testKeyManagerSuite) TestNewKeyManagerWithCustomConfig(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + // Cancel background loop. 
+ cancel() + // Custom config + rotatePeriod, err := time.ParseDuration("100h") + c.Assert(err, IsNil) + config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + DataKeyRotationPeriod: typeutil.NewDuration(rotatePeriod), + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + // Check config. + c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) + c.Assert(m.dataKeyRotationPeriod, Equals, rotatePeriod) + c.Assert(m.masterKeyMeta, NotNil) + keyFileMeta := m.masterKeyMeta.GetFile() + c.Assert(keyFileMeta, NotNil) + c.Assert(keyFileMeta.Path, Equals, config.MasterKey.MasterKeyFileConfig.FilePath) + // Check loaded keys. + c.Assert(m.keys.Load(), IsNil) + // Check etcd KV. + value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + c.Assert(value, IsNil) +} + +func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Use default config. + config := &encryption.Config{} + err := config.Adjust() + c.Assert(err, IsNil) + // Store initial keys in etcd. 
+ masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + }, + } + err = saveKeys(client, leadership, masterKeyMeta, keys) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + // Check config. + c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_PLAINTEXT) + c.Assert(m.masterKeyMeta.GetPlaintext(), NotNil) + // Check loaded keys. + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Check etcd KV. + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) +} + +func (s *testKeyManagerSuite) TestGetCurrentKey(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + // Cancel background loop. + cancel() + // Use default config. + config := &encryption.Config{} + err := config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + // Test encryption disabled. + currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(err, IsNil) + c.Assert(currentKeyID, Equals, uint64(disableEncryptionKeyID)) + c.Assert(currentKey, IsNil) + // Test normal case. 
+ keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + }, + } + m.keys.Store(keys) + currentKeyID, currentKey, err = m.GetCurrentKey() + c.Assert(err, IsNil) + c.Assert(currentKeyID, Equals, keys.CurrentKeyId) + c.Assert(proto.Equal(currentKey, keys.Keys[keys.CurrentKeyId]), IsTrue) + // Test current key missing. + keys = &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: make(map[uint64]*encryptionpb.DataKey), + } + m.keys.Store(keys) + currentKeyID, currentKey, err = m.GetCurrentKey() + c.Assert(err, NotNil) +} + +func (s *testKeyManagerSuite) TestGetKey(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Store initial keys in etcd. + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + 456: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679534), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + // Use default config. + config := &encryption.Config{} + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. 
+ m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + // Get existing key. + key, err := m.GetKey(uint64(123)) + c.Assert(err, IsNil) + c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + // Cancel background loop. + cancel() + // Get key that require a reload. + // Deliberately cancel watcher, delete a key and check if it has reloaded. + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + loadedKeys = proto.Clone(loadedKeys).(*encryptionpb.KeyDictionary) + delete(loadedKeys.Keys, 456) + m.keys.Store(loadedKeys) + key, err = m.GetKey(uint64(456)) + c.Assert(err, IsNil) + c.Assert(proto.Equal(key, keys.Keys[456]), IsTrue) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Get non-existing key. + key, err = m.GetKey(uint64(789)) + c.Assert(err, NotNil) +} + +func (s *testKeyManagerSuite) TestWatcher(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReloadByWatcher = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReloadByWatcher = func() {} }() + // Use default config. + config := &encryption.Config{} + err := config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. 
+ m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + key, err := m.GetKey(123) + c.Assert(err, NotNil) + key, err = m.GetKey(456) + c.Assert(err, NotNil) + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + }, + } + err = saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + <-reloadEvent + key, err = m.GetKey(123) + c.Assert(err, IsNil) + c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + key, err = m.GetKey(456) + c.Assert(err, NotNil) + // Update again + keys = &encryptionpb.KeyDictionary{ + CurrentKeyId: 456, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + 456: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679534), + WasExposed: false, + }, + }, + } + err = saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + <-reloadEvent + key, err = m.GetKey(123) + c.Assert(err, IsNil) + c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + key, err = m.GetKey(456) + c.Assert(err, IsNil) + c.Assert(proto.Equal(key, keys.Keys[456]), IsTrue) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionOff(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + // Cancel background loop. + cancel() + // Use default config. 
+ config := &encryption.Config{} + err := config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(m.keys.Load(), IsNil) + // Set leadership + leadership := newTestLeader(c, client) + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption stays off. + c.Assert(m.keys.Load(), IsNil) + value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + c.Assert(value, IsNil) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Config with encryption on. + config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err := config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(m.keys.Load(), IsNil) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption is on and persisted. 
+ c.Assert(m.keys.Load(), NotNil) + currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(err, IsNil) + method, err := config.GetMethod() + c.Assert(err, IsNil) + c.Assert(currentKey.Method, Equals, method) + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + now = func() time.Time { return time.Unix(int64(1601679533), 0) } + defer func() { now = originalNow }() + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with different encryption method. + config := &encryption.Config{ + DataEncryptionMethod: "aes256-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. 
+ m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption method is updated. + c.Assert(m.keys.Load(), NotNil) + currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(err, IsNil) + c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES256_CTR) + c.Assert(currentKey.Key, HasLen, 32) + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) + c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + now = func() time.Time { return time.Unix(int64(1601679533), 0) } + defer func() { now = originalNow }() + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: true, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with different encryption method. 
+ config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption method is updated. + c.Assert(m.keys.Load(), NotNil) + currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(err, IsNil) + c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) + c.Assert(currentKey.Key, HasLen, 16) + c.Assert(currentKey.WasExposed, IsFalse) + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) + c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + now = func() time.Time { return time.Unix(int64(1601679533+101), 0) } + defer func() { now = originalNow }() + // Cancel background loop. 
+ cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with 100s rotation period. + rotationPeriod, err := time.ParseDuration("100s") + c.Assert(err, IsNil) + config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption method is updated. 
+ c.Assert(m.keys.Load(), NotNil) + currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(err, IsNil) + c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) + c.Assert(currentKey.Key, HasLen, 16) + c.Assert(currentKey.WasExposed, IsFalse) + c.Assert(currentKey.CreationTime, Equals, uint64(now().Unix())) + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) + c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + keyFile2, cleanupKeyFile2 := newTestKeyFile(c, testMasterKey2) + defer cleanupKeyFile2() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + now = func() time.Time { return time.Unix(int64(1601679533), 0) } + defer func() { now = originalNow }() + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with a different master key. 
+ config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile2, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check keys are the same, but encrypted with the new master key. + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) + meta, err := config.GetMasterKeyMeta() + c.Assert(err, IsNil) + checkMasterKeyMeta(c, resp.Kvs[0].Value, meta) +} + +func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Use default config. 
+ config := &encryption.Config{} + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check encryption is disabled + expectedKeys := proto.Clone(keys).(*encryptionpb.KeyDictionary) + expectedKeys.CurrentKeyId = disableEncryptionKeyID + expectedKeys.Keys[123].WasExposed = true + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), expectedKeys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, expectedKeys), IsTrue) +} + +func (s *testKeyManagerSuite) TestKeyRotation(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + mockNow := int64(1601679533) + now = func() time.Time { return time.Unix(atomic.LoadInt64(&mockNow), 0) } + defer func() { now = originalNow }() + originalTick := tick + mockTick := make(chan time.Time) + tick = func(ticker *time.Ticker) <-chan time.Time { return mockTick } + defer func() { tick = originalTick }() + // Listen on ticker event + tickerEvent := make(chan struct{}, 1) + eventAfterTicker = func() { + var e struct{} + tickerEvent <- e + } + defer func() { eventAfterTicker = func() {} }() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 
123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with 100s rotation period. + rotationPeriod, err := time.ParseDuration("100s") + c.Assert(err, IsNil) + config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check keys + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) + // Advance time and trigger ticker + atomic.AddInt64(&mockNow, int64(101)) + mockTick <- time.Unix(atomic.LoadInt64(&mockNow), 0) + <-tickerEvent + // Check key is rotated. 
+ currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) + c.Assert(currentKey.Key, HasLen, 16) + c.Assert(currentKey.CreationTime, Equals, uint64(mockNow)) + c.Assert(currentKey.WasExposed, IsFalse) + loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) + c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) + c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) + resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + c.Assert(err, IsNil) + storedKeys, err = loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, loadedKeys), IsTrue) +} From a6d13595617082846f15f75b8cb9f22f97b2e399 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 03:29:28 +0800 Subject: [PATCH 22/37] make EncryptionKeysPath a const Signed-off-by: Yi Wu --- server/core/storage.go | 3 - server/encryptionkm/key_manager.go | 28 ++++---- server/encryptionkm/key_manager_test.go | 95 ++++++++++++++++++------- 3 files changed, 82 insertions(+), 44 deletions(-) diff --git a/server/core/storage.go b/server/core/storage.go index 1dcdfcb0d88..018d80a04bf 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -44,9 +44,6 @@ const ( replicationPath = "replication_mode" componentPath = "component" customScheduleConfigPath = "scheduler_config" - - // Reserved to encryption - encryptionKeysPath = "encryption" ) const ( diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 73e84674975..50cc7b9f48c 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -32,6 +32,9 @@ import ( ) const ( + // EncryptionKeysPath is the path to store keys in etcd. + EncryptionKeysPath = "encryption_keys" + // Special key id to denote encryption is currently not enabled. 
disableEncryptionKeyID = 0 // Check interval for data key rotation. @@ -42,10 +45,10 @@ const ( // Test helpers var ( - now = func() time.Time { return time.Now() } - tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } - eventAfterReloadByWatcher = func() {} - eventAfterTicker = func() {} + now = func() time.Time { return time.Now() } + tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } + eventAfterReload = func() {} + eventAfterTicker = func() {} ) // KeyManager maintains the list to encryption keys. It handles encryption key generation and @@ -69,11 +72,6 @@ type KeyManager struct { keys atomic.Value } -// EncryptionKeysPath return the path to store key dictionary in etcd. -func EncryptionKeysPath() string { - return "encryption/keys" -} - // saveKeys saves encryption keys in etcd. Fail if given leadership is not current. func saveKeys( etcdClient *clientv3.Client, @@ -112,7 +110,7 @@ func saveKeys( } // Avoid write conflict with PD peer by checking if we are leader. resp, err := leadership.LeaderTxn(). - Then(clientv3.OpPut(EncryptionKeysPath(), string(value))). + Then(clientv3.OpPut(EncryptionKeysPath, string(value))). Commit() if err != nil { return errs.ErrEtcdTxn.Wrap(err).GenWithStack("fail to save encryption keys") @@ -191,7 +189,7 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { defer watcher.Close() watcherCtx, cancel := context.WithCancel(ctx) defer cancel() - watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath(), clientv3.WithRev(revision)) + watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath, clientv3.WithRev(revision)) // Check data key rotation every min(dataKeyRotationPeriod, keyRotationCheckPeriod). 
checkPeriod := m.dataKeyRotationPeriod if keyRotationCheckPeriod < checkPeriod { @@ -220,7 +218,7 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { m.muUpdate.Unlock() } } - eventAfterReloadByWatcher() + eventAfterReload() // Check data key rotation in case we are the PD leader. case <-tick(ticker): m.muUpdate.Lock() @@ -251,7 +249,7 @@ func (m *KeyManager) loadKeysFromKV( } func (m *KeyManager) loadKeys() (keys *encryptionpb.KeyDictionary, revision int64, err error) { - resp, err := etcdutil.EtcdKVGet(m.etcdClient, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(m.etcdClient, EncryptionKeysPath) if err != nil { return nil, 0, err } @@ -336,12 +334,12 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { if !needUpdate { return nil } + // Store updated keys in etcd. err = saveKeys(m.etcdClient, m.leadership, m.masterKeyMeta, keys) if err != nil { return err } - // Update local keys. - m.keys.Store(keys) + // m.keys is not updated immediately. Defer to have watcher reload keys. return err } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 0ab1f5831d3..5a40ca3ac2d 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -142,7 +142,7 @@ func (s *testKeyManagerSuite) TestNewKeyManagerBasic(c *C) { // Check loaded keys. c.Assert(m.keys.Load(), IsNil) // Check etcd KV. - value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + value, err := etcdutil.GetValue(client, EncryptionKeysPath) c.Assert(err, IsNil) c.Assert(value, IsNil) } @@ -185,7 +185,7 @@ func (s *testKeyManagerSuite) TestNewKeyManagerWithCustomConfig(c *C) { // Check loaded keys. c.Assert(m.keys.Load(), IsNil) // Check etcd KV. 
- value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + value, err := etcdutil.GetValue(client, EncryptionKeysPath) c.Assert(err, IsNil) c.Assert(value, IsNil) } @@ -234,7 +234,7 @@ func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { // Check loaded keys. c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) // Check etcd KV. - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(err, IsNil) @@ -364,11 +364,11 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { leadership := newTestLeader(c, client) // Listen on watcher event reloadEvent := make(chan struct{}, 1) - eventAfterReloadByWatcher = func() { + eventAfterReload = func() { var e struct{} reloadEvent <- e } - defer func() { eventAfterReloadByWatcher = func() {} }() + defer func() { eventAfterReload = func() {} }() // Use default config. config := &encryption.Config{} err := config.Adjust() @@ -458,7 +458,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionOff(c *C) { c.Assert(err, IsNil) // Check encryption stays off. c.Assert(m.keys.Load(), IsNil) - value, err := etcdutil.GetValue(client, EncryptionKeysPath()) + value, err := etcdutil.GetValue(client, EncryptionKeysPath) c.Assert(err, IsNil) c.Assert(value, IsNil) } @@ -472,8 +472,13 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { keyFile, cleanupKeyFile := newTestKeyFile(c) defer cleanupKeyFile() leadership := newTestLeader(c, client) - // Cancel background loop. - cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Config with encryption on. 
config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", @@ -494,6 +499,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check encryption is on and persisted. + <-reloadEvent c.Assert(m.keys.Load(), NotNil) currentKeyID, currentKey, err := m.GetCurrentKey() c.Assert(err, IsNil) @@ -502,7 +508,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { c.Assert(currentKey.Method, Equals, method) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) @@ -521,8 +527,12 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) originalNow := now now = func() time.Time { return time.Unix(int64(1601679533), 0) } defer func() { now = originalNow }() - // Cancel background loop. - cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } // Update keys in etcd masterKeyMeta := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -564,6 +574,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check encryption method is updated. 
+ <-reloadEvent c.Assert(m.keys.Load(), NotNil) currentKeyID, currentKey, err := m.GetCurrentKey() c.Assert(err, IsNil) @@ -572,7 +583,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) @@ -591,8 +602,13 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { originalNow := now now = func() time.Time { return time.Unix(int64(1601679533), 0) } defer func() { now = originalNow }() - // Cancel background loop. - cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Update keys in etcd masterKeyMeta := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -634,6 +650,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check encryption method is updated. 
+ <-reloadEvent c.Assert(m.keys.Load(), NotNil) currentKeyID, currentKey, err := m.GetCurrentKey() c.Assert(err, IsNil) @@ -643,7 +660,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) @@ -662,8 +679,13 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { originalNow := now now = func() time.Time { return time.Unix(int64(1601679533+101), 0) } defer func() { now = originalNow }() - // Cancel background loop. - cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Update keys in etcd masterKeyMeta := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -708,6 +730,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check encryption method is updated. 
+ <-reloadEvent c.Assert(m.keys.Load(), NotNil) currentKeyID, currentKey, err := m.GetCurrentKey() c.Assert(err, IsNil) @@ -718,7 +741,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) @@ -739,8 +762,13 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { originalNow := now now = func() time.Time { return time.Unix(int64(1601679533), 0) } defer func() { now = originalNow }() - // Cancel background loop. - cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Update keys in etcd masterKeyMeta := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -782,8 +810,9 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check keys are the same, but encrypted with the new master key. + <-reloadEvent c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(err, IsNil) @@ -802,8 +831,13 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { keyFile, cleanupKeyFile := newTestKeyFile(c) defer cleanupKeyFile() leadership := newTestLeader(c, client) - // Cancel background loop. 
- cancel() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Update keys in etcd masterKeyMeta := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -837,11 +871,12 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { err = m.SetLeadership(leadership) c.Assert(err, IsNil) // Check encryption is disabled + <-reloadEvent expectedKeys := proto.Clone(keys).(*encryptionpb.KeyDictionary) expectedKeys.CurrentKeyId = disableEncryptionKeyID expectedKeys.Keys[123].WasExposed = true c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), expectedKeys), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(err, IsNil) @@ -866,6 +901,13 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { mockTick := make(chan time.Time) tick = func(ticker *time.Ticker) <-chan time.Time { return mockTick } defer func() { tick = originalTick }() + // Listen on watcher event + reloadEvent := make(chan struct{}, 1) + eventAfterReload = func() { + var e struct{} + reloadEvent <- e + } + defer func() { eventAfterReload = func() {} }() // Listen on ticker event tickerEvent := make(chan struct{}, 1) eventAfterTicker = func() { @@ -918,7 +960,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { c.Assert(err, IsNil) // Check keys c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) - resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err := loadKeysFromKV(resp.Kvs[0]) c.Assert(err, IsNil) @@ -927,6 +969,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { atomic.AddInt64(&mockNow, int64(101)) mockTick <- 
time.Unix(atomic.LoadInt64(&mockNow), 0) <-tickerEvent + <-reloadEvent // Check key is rotated. currentKeyID, currentKey, err := m.GetCurrentKey() c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) @@ -937,7 +980,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) - resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath()) + resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) storedKeys, err = loadKeysFromKV(resp.Kvs[0]) c.Assert(err, IsNil) From 103a61251558f8e71ad25d26963cc67504cc76a3 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 03:31:48 +0800 Subject: [PATCH 23/37] fix region_crypter key manager nil check Signed-off-by: Yi Wu --- pkg/encryption/region_crypter.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/encryption/region_crypter.go b/pkg/encryption/region_crypter.go index bd379d4b92a..cf1069746c2 100644 --- a/pkg/encryption/region_crypter.go +++ b/pkg/encryption/region_crypter.go @@ -16,6 +16,7 @@ package encryption import ( "crypto/aes" "crypto/cipher" + "reflect" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/encryptionpb" @@ -47,7 +48,8 @@ func EncryptRegion(region *metapb.Region, keyManager KeyManager) error { return errs.ErrEncryptionEncryptRegion.GenWithStack( "region already encrypted, region id = %d", region.Id) } - if keyManager == nil { + if keyManager == nil || + (reflect.TypeOf(keyManager).Kind() == reflect.Ptr && reflect.ValueOf(keyManager).IsNil()) { // encryption is not enabled. 
return nil } @@ -88,7 +90,8 @@ func DecryptRegion(region *metapb.Region, keyManager KeyManager) error { if region.EncryptionMeta == nil { return nil } - if keyManager == nil { + if keyManager == nil || + (reflect.TypeOf(keyManager).Kind() == reflect.Ptr && reflect.ValueOf(keyManager).IsNil()) { return errs.ErrEncryptionDecryptRegion.GenWithStack( "unable to decrypt region without encryption keys") } From e9d810dfcf619e0843f894192493470cff90876b Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 04:07:27 +0800 Subject: [PATCH 24/37] save conflict test Signed-off-by: Yi Wu --- server/encryptionkm/key_manager.go | 12 ++- server/encryptionkm/key_manager_test.go | 101 ++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 50cc7b9f48c..3023c67fac2 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -45,10 +45,12 @@ const ( // Test helpers var ( - now = func() time.Time { return time.Now() } - tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } - eventAfterReload = func() {} - eventAfterTicker = func() {} + now = func() time.Time { return time.Now() } + tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } + eventAfterReload = func() {} + eventAfterTicker = func() {} + eventAfterLeaderCheck = func() {} + eventSaveKeysFailure = func() {} ) // KeyManager maintains the list to encryption keys. It handles encryption key generation and @@ -269,6 +271,7 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { m.leadership = nil return nil } + eventAfterLeaderCheck() // Reload encryption keys in case we are not up-to-date. keys, _, err := m.loadKeys() if err != nil { @@ -337,6 +340,7 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { // Store updated keys in etcd. 
err = saveKeys(m.etcdClient, m.leadership, m.masterKeyMeta, keys) if err != nil { + eventSaveKeysFailure() return err } // m.keys is not updated immediately. Defer to have watcher reload keys. diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 5a40ca3ac2d..883c1070e2f 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -972,6 +972,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { <-reloadEvent // Check key is rotated. currentKeyID, currentKey, err := m.GetCurrentKey() + c.Assert(currentKeyID, Not(Equals), uint64(123)) c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) c.Assert(currentKey.Key, HasLen, 16) c.Assert(currentKey.CreationTime, Equals, uint64(mockNow)) @@ -986,3 +987,103 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, loadedKeys), IsTrue) } + +func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { + // Initialize. 
+ ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Mock time + originalNow := now + mockNow := int64(1601679533) + now = func() time.Time { return time.Unix(atomic.LoadInt64(&mockNow), 0) } + defer func() { now = originalNow }() + originalTick := tick + mockTick := make(chan time.Time) + tick = func(ticker *time.Ticker) <-chan time.Time { return mockTick } + defer func() { tick = originalTick }() + // Listen on ticker event + tickerEvent := make(chan struct{}, 1) + eventAfterTicker = func() { + var e struct{} + tickerEvent <- e + } + defer func() { eventAfterTicker = func() {} }() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with 100s rotation period. + rotationPeriod, err := time.ParseDuration("100s") + c.Assert(err, IsNil) + config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. 
+ m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + // Check keys + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) + // Invalidate leader after leader check. + eventAfterLeaderCheck = func() { + leadership.Reset() + } + defer func() { eventAfterLeaderCheck = func() {} }() + // Listen on save key failure event + saveKeysFailureEvent := make(chan struct{}, 1) + eventSaveKeysFailure = func() { + var e struct{} + saveKeysFailureEvent <- e + } + defer func() { eventSaveKeysFailure = func() {} }() + // Advance time and trigger ticker + atomic.AddInt64(&mockNow, int64(101)) + mockTick <- time.Unix(atomic.LoadInt64(&mockNow), 0) + <-tickerEvent + <-saveKeysFailureEvent + // Check keys is unchanged. + resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) + c.Assert(err, IsNil) + storedKeys, err = loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) +} From eb5b6b15af2ce33fe9f12f8adb95ab6e21005e0c Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 04:40:29 +0800 Subject: [PATCH 25/37] sanity check keys revision Signed-off-by: Yi Wu --- server/encryptionkm/key_manager.go | 39 ++++++++++++++++--------- server/encryptionkm/key_manager_test.go | 1 + 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 3023c67fac2..fb1d39f64c6 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -67,8 +67,11 @@ type KeyManager struct { // Mutex for updating keys. 
Used for both of LoadKeys() and rotateKeyIfNeeded(). muUpdate sync.Mutex // PD leadership of the current PD node. Only the PD leader will rotate data keys, - // or change current encryption method. Guarded by muUpdate. + // or change current encryption method. + // Guarded by muUpdate. leadership *election.Leadership + // Revision of keys loaded from etcd. Guarded by muUpdate. + keysRevision int64 // List of all encryption keys and current encryption key id, // with type *encryptionpb.KeyDictionary keys atomic.Value @@ -174,16 +177,16 @@ func NewKeyManager( masterKeyMeta: masterKeyMeta, } // Load encryption keys from storage. - _, revision, err := m.loadKeys() + _, err = m.loadKeys() if err != nil { return nil, err } // Start periodic check for keys change and rotation key if needed. - go m.startBackgroundLoop(ctx, revision) + go m.startBackgroundLoop(ctx) return m, nil } -func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { +func (m *KeyManager) startBackgroundLoop(ctx context.Context) { // Create new context for the loop. loopCtx, _ := context.WithCancel(ctx) // Setup key dictionary watcher @@ -191,7 +194,7 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { defer watcher.Close() watcherCtx, cancel := context.WithCancel(ctx) defer cancel() - watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath, clientv3.WithRev(revision)) + watchChan := watcher.Watch(watcherCtx, EncryptionKeysPath, clientv3.WithRev(m.keysRevision)) // Check data key rotation every min(dataKeyRotationPeriod, keyRotationCheckPeriod). checkPeriod := m.dataKeyRotationPeriod if keyRotationCheckPeriod < checkPeriod { @@ -205,7 +208,7 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { // Reload encryption keys updated by PD leader (could be ourselves). case resp := <-watchChan: if resp.Canceled { - // If the watcher failed, we rely solely on rotateKeyIfNeeded to reload encryption keys. 
+ // If the watcher failed, we rely solely on rotateKeyIfNeeded() to reload encryption keys. log.Warn("encryption key watcher canceled") continue } @@ -241,28 +244,35 @@ func (m *KeyManager) startBackgroundLoop(ctx context.Context, revision int64) { func (m *KeyManager) loadKeysFromKV( kv *mvccpb.KeyValue, ) (*encryptionpb.KeyDictionary, error) { + // Sanity check if keys revision is in order. + // etcd docs indicates watcher event can be out of order: + // https://etcd.io/docs/v3.4.0/learning/api_guarantees/#isolation-level-and-consistency-of-replicas + if kv.ModRevision <= m.keysRevision { + return m.getKeys(), nil + } keys, err := loadKeysFromKV(kv) if err != nil { return nil, err } + m.keysRevision = kv.ModRevision m.keys.Store(keys) log.Info("reloaded encryption keys", zap.Int64("revision", kv.ModRevision)) return keys, nil } -func (m *KeyManager) loadKeys() (keys *encryptionpb.KeyDictionary, revision int64, err error) { +func (m *KeyManager) loadKeys() (keys *encryptionpb.KeyDictionary, err error) { resp, err := etcdutil.EtcdKVGet(m.etcdClient, EncryptionKeysPath) if err != nil { - return nil, 0, err + return nil, err } if resp == nil || len(resp.Kvs) == 0 { - return nil, 0, nil + return nil, nil } keys, err = m.loadKeysFromKV(resp.Kvs[0]) if err != nil { - return nil, 0, err + return nil, err } - return keys, resp.Kvs[0].ModRevision, err + return keys, err } func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { @@ -273,7 +283,7 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { } eventAfterLeaderCheck() // Reload encryption keys in case we are not up-to-date. - keys, _, err := m.loadKeys() + keys, err := m.loadKeys() if err != nil { return err } @@ -343,7 +353,8 @@ func (m *KeyManager) rotateKeyIfNeeded(forceUpdate bool) error { eventSaveKeysFailure() return err } - // m.keys is not updated immediately. Defer to have watcher reload keys. + // Reload keys. 
+ _, err = m.loadKeys() return err } @@ -398,7 +409,7 @@ func (m *KeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { return key, nil } // Reload keys from storage. - keys, _, err := m.loadKeys() + keys, err := m.loadKeys() if err != nil { return nil, err } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 883c1070e2f..bfa28cf6c8e 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -344,6 +344,7 @@ func (s *testKeyManagerSuite) TestGetKey(c *C) { loadedKeys = proto.Clone(loadedKeys).(*encryptionpb.KeyDictionary) delete(loadedKeys.Keys, 456) m.keys.Store(loadedKeys) + m.keysRevision = 0 key, err = m.GetKey(uint64(456)) c.Assert(err, IsNil) c.Assert(proto.Equal(key, keys.Keys[456]), IsTrue) From c97b672d59922de452b0b5cba6f51390afbca4c6 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 06:42:02 +0800 Subject: [PATCH 26/37] kms Signed-off-by: Yi Wu --- go.mod | 2 + go.sum | 9 +++ pkg/encryption/kms.go | 114 ++++++++++++++++++++++++++++- pkg/encryption/master_key.go | 25 ++++--- server/encryptionkm/key_manager.go | 4 +- 5 files changed, 141 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index c6c9344b9e1..9e6b551e5b1 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/BurntSushi/toml v0.3.1 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 + github.com/aws/aws-sdk-go v1.35.3 github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5 github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e github.com/coreos/go-semver v0.2.0 @@ -35,6 +36,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20200927025644-73dc27044686 github.com/pingcap/log v0.0.0-20200511115504-543df19646ad github.com/pingcap/sysutil v0.0.0-20200715082929-4c47bcac246a + github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.0.0 github.com/prometheus/common v0.4.1 github.com/sasha-s/go-deadlock v0.2.0 diff 
--git a/go.sum b/go.sum index 4db5a2a5dc9..8e6b18d1773 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/appleboy/gin-jwt/v2 v2.6.3 h1:aK4E3DjihWEBUTjEeRnGkA5nUkmwJPL1CPonMa2 github.com/appleboy/gin-jwt/v2 v2.6.3/go.mod h1:MfPYA4ogzvOcVkRwAxT7quHOtQmVKDpTwxyUrC2DNw0= github.com/appleboy/gofight/v2 v2.1.2 h1:VOy3jow4vIK8BRQJoC/I9muxyYlJ2yb9ht2hZoS3rf4= github.com/appleboy/gofight/v2 v2.1.2/go.mod h1:frW+U1QZEdDgixycTj4CygQ48yLTUhplt43+Wczp3rw= +github.com/aws/aws-sdk-go v1.35.3 h1:r0puXncSaAfRt7Btml2swUo74Kao+vKhO3VLjwDjK54= +github.com/aws/aws-sdk-go v1.35.3/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= @@ -193,6 +195,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= @@ -504,6 +510,7 @@ golang.org/x/net 
v0.0.0-20190813141303-74dc4d7220e7 h1:fHDIZ2oxGnUZRN6WgWFCbYBjH golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -620,6 +627,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/encryption/kms.go b/pkg/encryption/kms.go index eb613cceb14..f8b50aab600 100644 --- a/pkg/encryption/kms.go +++ b/pkg/encryption/kms.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +13,119 @@ package encryption +import ( + "os" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pkg/errors" +) + const ( // We only support AWS KMS right now. kmsVendorAWS = "AWS" + + // K8S IAM related environment variables. + envAwsRoleArn = "AWS_ROLE_ARN" + envAwsWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE" + envAwsRoleSessionName = "AWS_ROLE_SESSION_NAME" ) + +func newMasterKeyFromKMS( + config *encryptionpb.MasterKeyKms, + ciphertextKey []byte, +) (*MasterKey, error) { + if config == nil { + return nil, errors.New("missing master key KMS config") + } + if config.Vendor != kmsVendorAWS { + return nil, errors.Errorf("unsupported KMS vendor: %s", config.Vendor) + } + credentials, err := newAwsCredentials() + if err != nil { + return nil, err + } + session, err := session.NewSession(&aws.Config{ + Credentials: credentials, + Region: &config.Region, + Endpoint: &config.Endpoint, + }) + if err != nil { + return nil, errors.Wrap(err, "fail to create AWS session to access KMS CMK") + } + client := kms.New(session) + if len(ciphertextKey) == 0 { + numberOfBytes := int64(masterKeyLength) + // Create a new data key. 
+		output, err := client.GenerateDataKey(&kms.GenerateDataKeyInput{
+			KeyId:         &config.KeyId,
+			NumberOfBytes: &numberOfBytes,
+		})
+		if err != nil {
+			return nil, errors.Wrap(err, "fail to generate data key from AWS KMS")
+		}
+		if len(output.Plaintext) != masterKeyLength {
+			return nil, errors.Errorf(
+				"unexpected data key length generated from AWS KMS, expected %d vs actual %d",
+				masterKeyLength, len(output.Plaintext))
+		}
+		return &MasterKey{
+			key:           output.Plaintext,
+			ciphertextKey: output.CiphertextBlob,
+		}, nil
+	} else {
+		// Decrypt existing data key.
+		output, err := client.Decrypt(&kms.DecryptInput{
+			KeyId:          &config.KeyId,
+			CiphertextBlob: ciphertextKey,
+		})
+		if err != nil {
+			return nil, errors.Wrap(err, "fail to decrypt data key from AWS KMS")
+		}
+		if len(output.Plaintext) != masterKeyLength {
+			return nil, errors.Errorf(
+				"unexpected data key length decrypted from AWS KMS, expected %d vs actual %d",
+				masterKeyLength, len(output.Plaintext))
+		}
+		return &MasterKey{
+			key:           output.Plaintext,
+			ciphertextKey: ciphertextKey,
+		}, nil
+	}
+}
+
+func newAwsCredentials() (*credentials.Credentials, error) {
+	var providers []credentials.Provider
+
+	// Credentials from K8S IAM role.
+	roleArn := os.Getenv(envAwsRoleArn)
+	tokenFile := os.Getenv(envAwsWebIdentityTokenFile)
+	sessionName := os.Getenv(envAwsRoleSessionName)
+	// Session name is optional.
+	if roleArn != "" && tokenFile != "" {
+		session, err := session.NewSession()
+		if err != nil {
+			return nil, errors.Wrap(err, "fail to create AWS session to create a WebIdentityRoleProvider")
+		}
+		webIdentityProvider := stscreds.NewWebIdentityRoleProvider(
+			sts.New(session), roleArn, sessionName, tokenFile)
+		providers = append(providers, webIdentityProvider)
+	}
+
+	// Credentials from AWS environment variables.
+	providers = append(providers, &credentials.EnvProvider{})
+
+	// Credentials from default AWS credentials file.
+ providers = append(providers, &credentials.SharedCredentialsProvider{ + Filename: "", + Profile: "", + }) + + credentials := credentials.NewChainCredentials(providers) + return credentials, nil +} diff --git a/pkg/encryption/master_key.go b/pkg/encryption/master_key.go index 952fa443bba..dc5cdeb466f 100644 --- a/pkg/encryption/master_key.go +++ b/pkg/encryption/master_key.go @@ -33,11 +33,13 @@ type MasterKey struct { // Encryption key in plaintext. If it is nil, encryption is no-op. // Never output it to info log or persist it on disk. key []byte + // Key in ciphertext form. Used by KMS key type. + ciphertextKey []byte } // NewMasterKey obtains a master key from backend specified by given config. // The config may be altered to fill in metadata generated when initializing the master key. -func NewMasterKey(config *encryptionpb.MasterKey) (*MasterKey, error) { +func NewMasterKey(config *encryptionpb.MasterKey, ciphertextKey []byte) (*MasterKey, error) { if config == nil { return nil, errs.ErrEncryptionNewMasterKey.GenWithStack("master key config is empty") } @@ -47,13 +49,10 @@ func NewMasterKey(config *encryptionpb.MasterKey) (*MasterKey, error) { }, nil } if file := config.GetFile(); file != nil { - key, err := newMasterKeyFromFile(file) - if err != nil { - return nil, err - } - return &MasterKey{ - key: key, - }, nil + return newMasterKeyFromFile(file) + } + if kms := config.GetKms(); kms != nil { + return newMasterKeyFromKMS(kms, ciphertextKey) } return nil, errors.New("unrecognized master key type") } @@ -84,10 +83,16 @@ func (k *MasterKey) IsPlaintext() bool { return k.key == nil } +// CiphertextKey returns the key in encrypted form. +// KMS key type recover the key by decrypting the ciphertextKey from KMS. +func (k *MasterKey) CiphertextKey() []byte { + return k.ciphertextKey +} + // newMasterKeyFromFile reads a hex-string from file specified in the config, and construct a // MasterKey object. The key must be of 256 bits (32 bytes). 
The file can contain leading and // tailing spaces. -func newMasterKeyFromFile(config *encryptionpb.MasterKeyFile) ([]byte, error) { +func newMasterKeyFromFile(config *encryptionpb.MasterKeyFile) (*MasterKey, error) { if config == nil { return nil, errs.ErrEncryptionNewMasterKey.GenWithStack("missing master key file config") } @@ -110,5 +115,5 @@ func newMasterKeyFromFile(config *encryptionpb.MasterKeyFile) ([]byte, error) { "unexpected key length from master key file, expected %d vs actual %d", masterKeyLength, len(key)) } - return key, nil + return &MasterKey{key: key}, nil } diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index fb1d39f64c6..18d2b71cd28 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -85,7 +85,7 @@ func saveKeys( keys *encryptionpb.KeyDictionary, ) error { // Get master key. - masterKey, err := encryption.NewMasterKey(masterKeyMeta) + masterKey, err := encryption.NewMasterKey(masterKeyMeta, nil) if err != nil { return err } @@ -139,7 +139,7 @@ func loadKeysFromKV(kv *mvccpb.KeyValue) (*encryptionpb.KeyDictionary, error) { return nil, errs.ErrEncryptionLoadKeys.GenWithStack( "no master key config found with encryption keys") } - masterKey, err := encryption.NewMasterKey(masterKeyConfig) + masterKey, err := encryption.NewMasterKey(masterKeyConfig, content.CiphertextKey) if err != nil { return nil, err } From bc3f35f88317ffa0d94deeca6bbbf35439fbedfa Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 14:30:25 +0800 Subject: [PATCH 27/37] test set ciphertextKey Signed-off-by: Yi Wu --- pkg/encryption/master_key.go | 9 +++ server/encryptionkm/key_manager.go | 12 +-- server/encryptionkm/key_manager_test.go | 98 +++++++++++++++++++++++-- 3 files changed, 109 insertions(+), 10 deletions(-) diff --git a/pkg/encryption/master_key.go b/pkg/encryption/master_key.go index dc5cdeb466f..6037974040a 100644 --- a/pkg/encryption/master_key.go +++ b/pkg/encryption/master_key.go 
@@ -57,6 +57,15 @@ func NewMasterKey(config *encryptionpb.MasterKey, ciphertextKey []byte) (*Master return nil, errors.New("unrecognized master key type") } +// NewCustomMasterKey construct a master key instance from raw key and ciphertext key bytes. +// Used for test only. +func NewCustomMasterKey(key []byte, ciphertextKey []byte) *MasterKey { + return &MasterKey{ + key: key, + ciphertextKey: ciphertextKey, + } +} + // Encrypt encrypts given plaintext using the master key. // IV is randomly generated and included in the result. Caller is expected to pass the same IV back // for decryption. diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 18d2b71cd28..14c2c03a2f0 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -47,6 +47,7 @@ const ( var ( now = func() time.Time { return time.Now() } tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } + newMasterKey = encryption.NewMasterKey eventAfterReload = func() {} eventAfterTicker = func() {} eventAfterLeaderCheck = func() {} @@ -85,7 +86,7 @@ func saveKeys( keys *encryptionpb.KeyDictionary, ) error { // Get master key. 
- masterKey, err := encryption.NewMasterKey(masterKeyMeta, nil) + masterKey, err := newMasterKey(masterKeyMeta, nil) if err != nil { return err } @@ -105,9 +106,10 @@ func saveKeys( return err } content := &encryptionpb.EncryptedContent{ - Content: ciphertextContent, - MasterKey: masterKeyMeta, - Iv: iv, + Content: ciphertextContent, + MasterKey: masterKeyMeta, + Iv: iv, + CiphertextKey: masterKey.CiphertextKey(), } value, err := proto.Marshal(content) if err != nil { @@ -139,7 +141,7 @@ func loadKeysFromKV(kv *mvccpb.KeyValue) (*encryptionpb.KeyDictionary, error) { return nil, errs.ErrEncryptionLoadKeys.GenWithStack( "no master key config found with encryption keys") } - masterKey, err := encryption.NewMasterKey(masterKeyConfig, content.CiphertextKey) + masterKey, err := newMasterKey(masterKeyConfig, content.CiphertextKey) if err != nil { return nil, err } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index bfa28cf6c8e..bcacbd442a6 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -14,6 +14,7 @@ package encryptionkm import ( + "bytes" "context" "encoding/hex" "fmt" @@ -45,9 +46,10 @@ type testKeyManagerSuite struct{} var _ = Suite(&testKeyManagerSuite{}) const ( - testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" - testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" - testDataKey = "be798242dde0c40d9a65cdbc36c1c9ac" + testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" + testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" + testCiphertextKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b532" + testDataKey = "be798242dde0c40d9a65cdbc36c1c9ac" ) func getTestDataKey() []byte { @@ -114,11 +116,12 @@ func newTestLeader(c *C, client *clientv3.Client) *election.Leadership { return leader } -func checkMasterKeyMeta(c *C, value 
[]byte, meta *encryptionpb.MasterKey) { +func checkMasterKeyMeta(c *C, value []byte, meta *encryptionpb.MasterKey, ciphertextKey []byte) { content := &encryptionpb.EncryptedContent{} err := content.Unmarshal(value) c.Assert(err, IsNil) c.Assert(proto.Equal(content.MasterKey, meta), IsTrue) + c.Assert(bytes.Equal(content.CiphertextKey, ciphertextKey), IsTrue) } func (s *testKeyManagerSuite) TestNewKeyManagerBasic(c *C) { @@ -820,7 +823,92 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { c.Assert(proto.Equal(storedKeys, keys), IsTrue) meta, err := config.GetMasterKeyMeta() c.Assert(err, IsNil) - checkMasterKeyMeta(c, resp.Kvs[0].Value, meta) + checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, nil) +} + +func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with a different master key. 
+ config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // mock NewMasterKey + originalNewMasterKey := newMasterKey + newMasterKeyCalled := 0 + outputMasterKey, _ := hex.DecodeString(testMasterKey) + outputCiphertextKey, _ := hex.DecodeString(testCiphertextKey) + newMasterKey = func( + meta *encryptionpb.MasterKey, + ciphertext []byte, + ) (*encryption.MasterKey, error) { + if newMasterKeyCalled == 0 { + // called by saveKeys + c.Assert(ciphertext, IsNil) + } else { + // called by loadKeys after saveKeys + c.Assert(bytes.Equal(ciphertext, outputCiphertextKey), IsTrue) + } + newMasterKeyCalled += 1 + return encryption.NewCustomMasterKey(outputMasterKey, outputCiphertextKey), nil + } + defer func() { newMasterKey = originalNewMasterKey }() + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + c.Assert(newMasterKeyCalled, Equals, 2) + // Check if keys are the same + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) + meta, err := config.GetMasterKeyMeta() + c.Assert(err, IsNil) + // Check ciphertext key is stored with keys. 
+ checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, outputCiphertextKey) } func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { From b5456dfb8abd06045f3fa34c23d8c2092e131597 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 6 Oct 2020 14:30:25 +0800 Subject: [PATCH 28/37] test set ciphertextKey Signed-off-by: Yi Wu --- pkg/encryption/master_key.go | 9 +++ server/encryptionkm/key_manager.go | 12 +-- server/encryptionkm/key_manager_test.go | 98 +++++++++++++++++++++++-- 3 files changed, 109 insertions(+), 10 deletions(-) diff --git a/pkg/encryption/master_key.go b/pkg/encryption/master_key.go index dc5cdeb466f..6037974040a 100644 --- a/pkg/encryption/master_key.go +++ b/pkg/encryption/master_key.go @@ -57,6 +57,15 @@ func NewMasterKey(config *encryptionpb.MasterKey, ciphertextKey []byte) (*Master return nil, errors.New("unrecognized master key type") } +// NewCustomMasterKey construct a master key instance from raw key and ciphertext key bytes. +// Used for test only. +func NewCustomMasterKey(key []byte, ciphertextKey []byte) *MasterKey { + return &MasterKey{ + key: key, + ciphertextKey: ciphertextKey, + } +} + // Encrypt encrypts given plaintext using the master key. // IV is randomly generated and included in the result. Caller is expected to pass the same IV back // for decryption. diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 18d2b71cd28..14c2c03a2f0 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -47,6 +47,7 @@ const ( var ( now = func() time.Time { return time.Now() } tick = func(ticker *time.Ticker) <-chan time.Time { return ticker.C } + newMasterKey = encryption.NewMasterKey eventAfterReload = func() {} eventAfterTicker = func() {} eventAfterLeaderCheck = func() {} @@ -85,7 +86,7 @@ func saveKeys( keys *encryptionpb.KeyDictionary, ) error { // Get master key. 
- masterKey, err := encryption.NewMasterKey(masterKeyMeta, nil) + masterKey, err := newMasterKey(masterKeyMeta, nil) if err != nil { return err } @@ -105,9 +106,10 @@ func saveKeys( return err } content := &encryptionpb.EncryptedContent{ - Content: ciphertextContent, - MasterKey: masterKeyMeta, - Iv: iv, + Content: ciphertextContent, + MasterKey: masterKeyMeta, + Iv: iv, + CiphertextKey: masterKey.CiphertextKey(), } value, err := proto.Marshal(content) if err != nil { @@ -139,7 +141,7 @@ func loadKeysFromKV(kv *mvccpb.KeyValue) (*encryptionpb.KeyDictionary, error) { return nil, errs.ErrEncryptionLoadKeys.GenWithStack( "no master key config found with encryption keys") } - masterKey, err := encryption.NewMasterKey(masterKeyConfig, content.CiphertextKey) + masterKey, err := newMasterKey(masterKeyConfig, content.CiphertextKey) if err != nil { return nil, err } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index bfa28cf6c8e..bcacbd442a6 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -14,6 +14,7 @@ package encryptionkm import ( + "bytes" "context" "encoding/hex" "fmt" @@ -45,9 +46,10 @@ type testKeyManagerSuite struct{} var _ = Suite(&testKeyManagerSuite{}) const ( - testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" - testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" - testDataKey = "be798242dde0c40d9a65cdbc36c1c9ac" + testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" + testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" + testCiphertextKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b532" + testDataKey = "be798242dde0c40d9a65cdbc36c1c9ac" ) func getTestDataKey() []byte { @@ -114,11 +116,12 @@ func newTestLeader(c *C, client *clientv3.Client) *election.Leadership { return leader } -func checkMasterKeyMeta(c *C, value 
[]byte, meta *encryptionpb.MasterKey) { +func checkMasterKeyMeta(c *C, value []byte, meta *encryptionpb.MasterKey, ciphertextKey []byte) { content := &encryptionpb.EncryptedContent{} err := content.Unmarshal(value) c.Assert(err, IsNil) c.Assert(proto.Equal(content.MasterKey, meta), IsTrue) + c.Assert(bytes.Equal(content.CiphertextKey, ciphertextKey), IsTrue) } func (s *testKeyManagerSuite) TestNewKeyManagerBasic(c *C) { @@ -820,7 +823,92 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { c.Assert(proto.Equal(storedKeys, keys), IsTrue) meta, err := config.GetMasterKeyMeta() c.Assert(err, IsNil) - checkMasterKeyMeta(c, resp.Kvs[0].Value, meta) + checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, nil) +} + +func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) { + // Initialize. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, cleanupEtcd := newTestEtcd(c) + defer cleanupEtcd() + keyFile, cleanupKeyFile := newTestKeyFile(c) + defer cleanupKeyFile() + leadership := newTestLeader(c, client) + // Cancel background loop. + cancel() + // Update keys in etcd + masterKeyMeta := &encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: keyFile, + }, + }, + } + keys := &encryptionpb.KeyDictionary{ + CurrentKeyId: 123, + Keys: map[uint64]*encryptionpb.DataKey{ + 123: &encryptionpb.DataKey{ + Key: getTestDataKey(), + Method: encryptionpb.EncryptionMethod_AES128_CTR, + CreationTime: uint64(1601679533), + WasExposed: false, + }, + }, + } + err := saveKeys(client, leadership, masterKeyMeta, keys) + c.Assert(err, IsNil) + // Config with a different master key. 
+ config := &encryption.Config{ + DataEncryptionMethod: "aes128-ctr", + MasterKey: encryption.MasterKeyConfig{ + Type: "file", + MasterKeyFileConfig: encryption.MasterKeyFileConfig{ + FilePath: keyFile, + }, + }, + } + err = config.Adjust() + c.Assert(err, IsNil) + // Create the key manager. + m, err := NewKeyManager(ctx, client, config) + c.Assert(err, IsNil) + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + // mock NewMasterKey + originalNewMasterKey := newMasterKey + newMasterKeyCalled := 0 + outputMasterKey, _ := hex.DecodeString(testMasterKey) + outputCiphertextKey, _ := hex.DecodeString(testCiphertextKey) + newMasterKey = func( + meta *encryptionpb.MasterKey, + ciphertext []byte, + ) (*encryption.MasterKey, error) { + if newMasterKeyCalled == 0 { + // called by saveKeys + c.Assert(ciphertext, IsNil) + } else { + // called by loadKeys after saveKeys + c.Assert(bytes.Equal(ciphertext, outputCiphertextKey), IsTrue) + } + newMasterKeyCalled += 1 + return encryption.NewCustomMasterKey(outputMasterKey, outputCiphertextKey), nil + } + defer func() { newMasterKey = originalNewMasterKey }() + // Set leadership + err = m.SetLeadership(leadership) + c.Assert(err, IsNil) + c.Assert(newMasterKeyCalled, Equals, 2) + // Check if keys are the same + c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) + c.Assert(err, IsNil) + storedKeys, err := loadKeysFromKV(resp.Kvs[0]) + c.Assert(err, IsNil) + c.Assert(proto.Equal(storedKeys, keys), IsTrue) + meta, err := config.GetMasterKeyMeta() + c.Assert(err, IsNil) + // Check ciphertext key is stored with keys. 
+	checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, outputCiphertextKey)
 }
 
 func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) {

From fcf3fe2ddfc68b89b1143fab9ead7f8ba119754a Mon Sep 17 00:00:00 2001
From: Yi Wu
Date: Sat, 10 Oct 2020 04:19:04 +0800
Subject: [PATCH 29/37] clone region only when needed

Signed-off-by: Yi Wu
---
 pkg/encryption/region_crypter.go      | 34 +++++++++++++++------------
 pkg/encryption/region_crypter_test.go | 26 +++++++++++---------
 server/core/region_storage.go         |  4 +---
 server/core/storage.go                |  3 +--
 4 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/pkg/encryption/region_crypter.go b/pkg/encryption/region_crypter.go
index bd379d4b92a..6762086d525 100644
--- a/pkg/encryption/region_crypter.go
+++ b/pkg/encryption/region_crypter.go
@@ -17,6 +17,7 @@ import (
 	"crypto/aes"
 	"crypto/cipher"
 
+	"github.com/gogo/protobuf/proto"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/kvproto/pkg/encryptionpb"
 	"github.com/pingcap/kvproto/pkg/metapb"
@@ -36,46 +37,48 @@ func processRegionKeys(region *metapb.Region, key *encryptionpb.DataKey, iv []by
 	return nil
 }
 
-// EncryptRegion encrypt the region start key and end key in-place,
-// using the current key return from the key manager. Encryption meta is updated accordingly.
-// Note: Call may need to make deep copy of the object if changing the object is undesired.
-func EncryptRegion(region *metapb.Region, keyManager KeyManager) error {
+// EncryptRegion encrypts the region start key and end key, using the current key returned from
+// the key manager. The result is an encrypted copy of the region, with encryption meta updated.
+func EncryptRegion(region *metapb.Region, keyManager KeyManager) (*metapb.Region, error) { if region == nil { - return errs.ErrEncryptionEncryptRegion.GenWithStack("trying to encrypt nil region") + return nil, errs.ErrEncryptionEncryptRegion.GenWithStack("trying to encrypt nil region") } if region.EncryptionMeta != nil { - return errs.ErrEncryptionEncryptRegion.GenWithStack( + return nil, errs.ErrEncryptionEncryptRegion.GenWithStack( "region already encrypted, region id = %d", region.Id) } if keyManager == nil { // encryption is not enabled. - return nil + return region, nil } keyID, key, err := keyManager.GetCurrentKey() if err != nil { - return err + return nil, err } if key == nil { // encryption is not enabled. - return nil + return region, nil } err = CheckEncryptionMethodSupported(key.Method) if err != nil { - return err + return nil, err } iv, err := NewIvCTR() if err != nil { - return err + return nil, err } - err = processRegionKeys(region, key, iv) + // Deep copy region before altering it. + outRegion := proto.Clone(region).(*metapb.Region) + // Encrypt and update in-place. + err = processRegionKeys(outRegion, key, iv) if err != nil { - return err + return nil, err } - region.EncryptionMeta = &encryptionpb.EncryptionMeta{ + outRegion.EncryptionMeta = &encryptionpb.EncryptionMeta{ KeyId: keyID, Iv: iv, } - return nil + return outRegion, nil } // DecryptRegion decrypt the region start key and end key, if the region object was encrypted. @@ -100,6 +103,7 @@ func DecryptRegion(region *metapb.Region, keyManager KeyManager) error { if err != nil { return err } + // Decrypt and update in-place. 
err = processRegionKeys(region, key, region.EncryptionMeta.Iv) if err != nil { return err diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index 45f94fde228..9778ae02981 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -78,8 +78,11 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { func (s *testRegionCrypterSuite) TestNilRegion(c *C) { m := newTestKeyManager() - c.Assert(EncryptRegion(nil, m), Not(IsNil)) - c.Assert(DecryptRegion(nil, m), Not(IsNil)) + region, err := EncryptRegion(nil, m) + c.Assert(err, NotNil) + c.Assert(region, IsNil) + err = DecryptRegion(nil, m) + c.Assert(err, NotNil) } func (s *testRegionCrypterSuite) TestEncryptRegionWithoutKeyManager(c *C) { @@ -89,7 +92,7 @@ func (s *testRegionCrypterSuite) TestEncryptRegionWithoutKeyManager(c *C) { EndKey: []byte("xyz"), EncryptionMeta: nil, } - err := EncryptRegion(region, nil) + region, err := EncryptRegion(region, nil) c.Assert(err, IsNil) // check the region isn't changed c.Assert(string(region.StartKey), Equals, "abc") @@ -106,7 +109,7 @@ func (s *testRegionCrypterSuite) TestEncryptRegionWhileEncryptionDisabled(c *C) } m := newTestKeyManager() m.EncryptionEnabled = false - err := EncryptRegion(region, m) + region, err := EncryptRegion(region, m) c.Assert(err, IsNil) // check the region isn't changed c.Assert(string(region.StartKey), Equals, "abc") @@ -126,24 +129,25 @@ func (s *testRegionCrypterSuite) TestEncryptRegion(c *C) { copy(region.StartKey, startKey) copy(region.EndKey, endKey) m := newTestKeyManager() - err := EncryptRegion(region, m) + outRegion, err := EncryptRegion(region, m) c.Assert(err, IsNil) + c.Assert(outRegion, Not(Equals), region) // check region is encrypted - c.Assert(region.EncryptionMeta, Not(IsNil)) - c.Assert(region.EncryptionMeta.KeyId, Equals, uint64(2)) - c.Assert(len(region.EncryptionMeta.Iv), Equals, ivLengthCTR) + c.Assert(outRegion.EncryptionMeta, 
Not(IsNil)) + c.Assert(outRegion.EncryptionMeta.KeyId, Equals, uint64(2)) + c.Assert(outRegion.EncryptionMeta.Iv, HasLen, ivLengthCTR) // Check encrypted content _, currentKey, err := m.GetCurrentKey() c.Assert(err, IsNil) block, err := aes.NewCipher(currentKey.Key) c.Assert(err, IsNil) - stream := cipher.NewCTR(block, region.EncryptionMeta.Iv) + stream := cipher.NewCTR(block, outRegion.EncryptionMeta.Iv) ciphertextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(ciphertextStartKey, startKey) - c.Assert(string(region.StartKey), Equals, string(ciphertextStartKey)) + c.Assert(string(outRegion.StartKey), Equals, string(ciphertextStartKey)) ciphertextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(ciphertextEndKey, endKey) - c.Assert(string(region.EndKey), Equals, string(ciphertextEndKey)) + c.Assert(string(outRegion.EndKey), Equals, string(ciphertextEndKey)) } func (s *testRegionCrypterSuite) TestDecryptRegionNotEncrypted(c *C) { diff --git a/server/core/region_storage.go b/server/core/region_storage.go index 3cf688d47fe..d5ee5d546bb 100644 --- a/server/core/region_storage.go +++ b/server/core/region_storage.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/encryption" @@ -105,8 +104,7 @@ func (s *RegionStorage) backgroundFlush() { // SaveRegion saves one region to storage. 
func (s *RegionStorage) SaveRegion(region *metapb.Region) error { - region = proto.Clone(region).(*metapb.Region) - err := encryption.EncryptRegion(region, s.encryptionKeyManager) + region, err := encryption.EncryptRegion(region, s.encryptionKeyManager) if err != nil { return err } diff --git a/server/core/storage.go b/server/core/storage.go index 1f811994d89..00efe2fa404 100644 --- a/server/core/storage.go +++ b/server/core/storage.go @@ -605,8 +605,7 @@ func saveRegion( encryptionKeyManager *encryptionkm.KeyManager, region *metapb.Region, ) error { - region = proto.Clone(region).(*metapb.Region) - err := encryption.EncryptRegion(region, encryptionKeyManager) + region, err := encryption.EncryptRegion(region, encryptionKeyManager) if err != nil { return err } From 4f2332f45e626f9e08ea8a6028b8a13ab92db97c Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Sat, 10 Oct 2020 05:14:58 +0800 Subject: [PATCH 30/37] add test for config Signed-off-by: Yi Wu --- pkg/encryption/config.go | 4 +++ pkg/encryption/config_test.go | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 pkg/encryption/config_test.go diff --git a/pkg/encryption/config.go b/pkg/encryption/config.go index e690f8bfec5..c58a04bce36 100644 --- a/pkg/encryption/config.go +++ b/pkg/encryption/config.go @@ -62,6 +62,10 @@ func (c *Config) Adjust() error { defaultDataKeyRotationPeriod) } c.DataKeyRotationPeriod.Duration = duration + } else if c.DataKeyRotationPeriod.Duration < 0 { + return errs.ErrEncryptionInvalidConfig.GenWithStack( + "negative data-key-rotation-period %d", + c.DataKeyRotationPeriod.Duration) } if len(c.MasterKey.Type) == 0 { c.MasterKey.Type = masterKeyTypePlaintext diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go new file mode 100644 index 00000000000..79253c23144 --- /dev/null +++ b/pkg/encryption/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2020 TiKV Project Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encryption + +import ( + "testing" + "time" + + . "github.com/pingcap/check" + "github.com/tikv/pd/pkg/typeutil" +) + +func TestConfig(t *testing.T) { + TestingT(t) +} + +type testConfigSuite struct{} + +var _ = Suite(&testConfigSuite{}) + +func (s *testConfigSuite) TestAdjustDefaultValue(c *C) { + config := &Config{} + err := config.Adjust() + c.Assert(err, IsNil) + c.Assert(config.DataEncryptionMethod, Equals, methodPlaintext) + defaultRotationPeriod, _ := time.ParseDuration(defaultDataKeyRotationPeriod) + c.Assert(config.DataKeyRotationPeriod.Duration, Equals, defaultRotationPeriod) + c.Assert(config.MasterKey.Type, Equals, masterKeyTypePlaintext) +} + +func (s *testConfigSuite) TestAdjustInvalidDataEncryptionMethod(c *C) { + config := &Config{DataEncryptionMethod: "unknown"} + c.Assert(config.Adjust(), NotNil) +} + +func (s *testConfigSuite) TestAdjustNegativeRotationDuration(c *C) { + config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} + c.Assert(config.Adjust(), NotNil) +} + +func (s *testConfigSuite) TestAdjustInvalidMasterKeyType(c *C) { + config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} + c.Assert(config.Adjust(), NotNil) +} From 1d42e57c2122280fe15b4c2dd7d33f3841671eef Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Sat, 10 Oct 2020 06:15:45 +0800 Subject: [PATCH 31/37] update errors Signed-off-by: Yi Wu --- go.mod | 1 - pkg/encryption/kms.go | 21 +++++++++++++-------- pkg/errs/errno.go | 1 + 3 files changed, 14 
insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 9e6b551e5b1..23191792edf 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,6 @@ require ( github.com/pingcap/kvproto v0.0.0-20200927025644-73dc27044686 github.com/pingcap/log v0.0.0-20200511115504-543df19646ad github.com/pingcap/sysutil v0.0.0-20200715082929-4c47bcac246a - github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.0.0 github.com/prometheus/common v0.4.1 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/pkg/encryption/kms.go b/pkg/encryption/kms.go index f8b50aab600..83c3f1fc2cd 100644 --- a/pkg/encryption/kms.go +++ b/pkg/encryption/kms.go @@ -22,8 +22,9 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/sts" + "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/encryptionpb" - "github.com/pkg/errors" + "github.com/tikv/pd/pkg/errs" ) const ( @@ -44,7 +45,7 @@ func newMasterKeyFromKMS( return nil, errors.New("missing master key KMS config") } if config.Vendor != kmsVendorAWS { - return nil, errors.Errorf("unsupported KMS vendor: %s", config.Vendor) + return nil, errs.ErrEncryptionKMS.GenWithStack("unsupported KMS vendor: %s", config.Vendor) } credentials, err := newAwsCredentials() if err != nil { @@ -56,7 +57,8 @@ func newMasterKeyFromKMS( Endpoint: &config.Endpoint, }) if err != nil { - return nil, errors.Wrap(err, "fail to create AWS session to access KMS CMK") + return nil, errs.ErrEncryptionKMS.Wrap(err).GenWithStack( + "fail to create AWS session to access KMS CMK") } client := kms.New(session) if len(ciphertextKey) == 0 { @@ -67,10 +69,11 @@ func newMasterKeyFromKMS( NumberOfBytes: &numberOfBytes, }) if err != nil { - return nil, errors.Wrap(err, "fail to generate data key from AWS KMS") + return nil, errs.ErrEncryptionKMS.Wrap(err).GenWithStack( + "fail to generate data key from AWS KMS") } if len(output.Plaintext) != masterKeyLength { - return nil, errors.Wrapf(err, + return nil, 
errs.ErrEncryptionKMS.GenWithStack( "unexpected data key length generated from AWS KMS, expectd %d vs actual %d", masterKeyLength, len(output.Plaintext)) } @@ -85,10 +88,11 @@ func newMasterKeyFromKMS( CiphertextBlob: ciphertextKey, }) if err != nil { - return nil, errors.Wrap(err, "fail to decrypt data key from AWS KMS") + return nil, errs.ErrEncryptionKMS.Wrap(err).GenWithStack( + "fail to decrypt data key from AWS KMS") } if len(output.Plaintext) != masterKeyLength { - return nil, errors.Wrapf(err, + return nil, errs.ErrEncryptionKMS.GenWithStack( "unexpected data key length decrypted from AWS KMS, expected %d vs actual %d", masterKeyLength, len(output.Plaintext)) } @@ -110,7 +114,8 @@ func newAwsCredentials() (*credentials.Credentials, error) { if roleArn != "" && tokenFile != "" { session, err := session.NewSession() if err != nil { - return nil, errors.Wrap(err, "fail to create AWS session to create a WebIdentityRoleProvider") + return nil, errs.ErrEncryptionKMS.Wrap(err).GenWithStack( + "fail to create AWS session to create a WebIdentityRoleProvider") } webIdentityProvider := stscreds.NewWebIdentityRoleProvider( sts.New(session), roleArn, sessionName, tokenFile) diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 79e04d562a2..942cf383b98 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -284,4 +284,5 @@ var ( ErrEncryptionLoadKeys = errors.Normalize("load data keys error", errors.RFCCodeText("PD:encryption:ErrEncryptionLoadKeys")) ErrEncryptionRotateDataKey = errors.Normalize("failed to rotate data key", errors.RFCCodeText("PD:encryption:ErrEncryptionRotateDataKey")) ErrEncryptionSaveDataKeys = errors.Normalize("failed to save data keys", errors.RFCCodeText("PD:encryption:ErrEncryptionSaveDataKeys")) + ErrEncryptionKMS = errors.Normalize("KMS error", errors.RFCCodeText("PD:ErrEncryptionKMS")) ) From 9cbc5ff6571e9848e22c3afc00e19ef02bd03641 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 3 Nov 2020 09:46:41 +0800 Subject: [PATCH 32/37] fix lint 
Signed-off-by: Yi Wu --- server/encryptionkm/key_manager_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 670687cff2d..1ec88077d18 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -893,7 +893,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) keys := &encryptionpb.KeyDictionary{ CurrentKeyId: 123, Keys: map[uint64]*encryptionpb.DataKey{ - 123: &encryptionpb.DataKey{ + 123: { Key: getTestDataKey(), Method: encryptionpb.EncryptionMethod_AES128_CTR, CreationTime: uint64(1601679533), From 813a8d960ead82f3f78c66ac6963deeb3e4d1c34 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 3 Nov 2020 12:28:35 +0800 Subject: [PATCH 33/37] fix lint Signed-off-by: Yi Wu --- pkg/encryption/kms.go | 11 ++++++----- pkg/encryption/master_key.go | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pkg/encryption/kms.go b/pkg/encryption/kms.go index 4494dffcf7d..040bd4936d7 100644 --- a/pkg/encryption/kms.go +++ b/pkg/encryption/kms.go @@ -40,7 +40,7 @@ const ( func newMasterKeyFromKMS( config *encryptionpb.MasterKeyKms, ciphertextKey []byte, -) (*MasterKey, error) { +) (masterKey *MasterKey, err error) { if config == nil { return nil, errors.New("missing master key KMS config") } @@ -77,10 +77,10 @@ func newMasterKeyFromKMS( "unexpected data key length generated from AWS KMS, expectd %d vs actual %d", masterKeyLength, len(output.Plaintext)) } - return &MasterKey{ + masterKey = &MasterKey{ key: output.Plaintext, ciphertextKey: output.CiphertextBlob, - }, nil + } } else { // Decrypt existing data key. 
output, err := client.Decrypt(&kms.DecryptInput{ @@ -96,11 +96,12 @@ func newMasterKeyFromKMS( "unexpected data key length decrypted from AWS KMS, expected %d vs actual %d", masterKeyLength, len(output.Plaintext)) } - return &MasterKey{ + masterKey = &MasterKey{ key: output.Plaintext, ciphertextKey: ciphertextKey, - }, nil + } } + return } func newAwsCredentials() (*credentials.Credentials, error) { diff --git a/pkg/encryption/master_key.go b/pkg/encryption/master_key.go index ae52c0df511..e8e340ac0cb 100644 --- a/pkg/encryption/master_key.go +++ b/pkg/encryption/master_key.go @@ -56,7 +56,7 @@ func NewMasterKey(config *encryptionpb.MasterKey, ciphertextKey []byte) (*Master return nil, errs.ErrEncryptionNewMasterKey.GenWithStack("unrecognized master key type") } -// NewCustomMasterKey construct a master key instance from raw key and ciphertext key bytes. +// NewCustomMasterKeyForTest construct a master key instance from raw key and ciphertext key bytes. // Used for test only. func NewCustomMasterKeyForTest(key []byte, ciphertextKey []byte) *MasterKey { return &MasterKey{ From 7cf9888c007d2a32b36c596f8ab29c5f58b44c53 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 3 Nov 2020 14:16:46 +0800 Subject: [PATCH 34/37] make errdoc Signed-off-by: Yi Wu --- errors.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/errors.toml b/errors.toml index 9987ee4041c..69d82cda524 100644 --- a/errors.toml +++ b/errors.toml @@ -1,6 +1,11 @@ # AUTOGENERATED BY github.com/pingcap/tiup/components/errdoc/errdoc-gen # DO NOT EDIT THIS FILE, PLEASE CHANGE ERROR DEFINITION IF CONTENT IMPROPER. 
+["PD:ErrEncryptionKMS"] +error = ''' +KMS error +''' + ["PD:apiutil:ErrRedirect"] error = ''' redirect failed From 9a6020fdc8d3355606009e68d11e7e715187a8d3 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Mon, 9 Nov 2020 13:34:16 +0800 Subject: [PATCH 35/37] address comment Signed-off-by: Yi Wu --- server/encryptionkm/key_manager.go | 18 ++------- server/encryptionkm/key_manager_test.go | 50 ++++++++++++------------- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/server/encryptionkm/key_manager.go b/server/encryptionkm/key_manager.go index 04b9c77e651..fbce747e7ed 100644 --- a/server/encryptionkm/key_manager.go +++ b/server/encryptionkm/key_manager.go @@ -76,15 +76,10 @@ func saveKeys( leadership *election.Leadership, masterKeyMeta *encryptionpb.MasterKey, keys *encryptionpb.KeyDictionary, - helper ...keyManagerHelper, + helper keyManagerHelper, ) (err error) { // Get master key. - var masterKey *encryption.MasterKey - if len(helper) > 0 { - masterKey, err = helper[0].newMasterKey(masterKeyMeta, nil) - } else { - masterKey, err = encryption.NewMasterKey(masterKeyMeta, nil) - } + masterKey, err := helper.newMasterKey(masterKeyMeta, nil) if err != nil { return err } @@ -133,7 +128,7 @@ func saveKeys( // extractKeysFromKV unpack encrypted keys from etcd KV. 
func extractKeysFromKV( kv *mvccpb.KeyValue, - helper ...keyManagerHelper, + helper keyManagerHelper, ) (*encryptionpb.KeyDictionary, error) { content := &encryptionpb.EncryptedContent{} err := content.Unmarshal(kv.Value) @@ -146,12 +141,7 @@ func extractKeysFromKV( return nil, errs.ErrEncryptionLoadKeys.GenWithStack( "no master key config found with encryption keys") } - var masterKey *encryption.MasterKey - if len(helper) > 0 { - masterKey, err = helper[0].newMasterKey(masterKeyConfig, content.CiphertextKey) - } else { - masterKey, err = encryption.NewMasterKey(masterKeyConfig, content.CiphertextKey) - } + masterKey, err := helper.newMasterKey(masterKeyConfig, content.CiphertextKey) if err != nil { return nil, err } diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 1ec88077d18..ce67995fb22 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -215,7 +215,7 @@ func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { }, }, } - err = saveKeys(leadership, masterKeyMeta, keys) + err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Create the key manager. m, err := NewKeyManager(client, config) @@ -228,7 +228,7 @@ func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { // Check etcd KV. resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) } @@ -308,7 +308,7 @@ func (s *testKeyManagerSuite) TestGetKey(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Use default config. 
config := &encryption.Config{} @@ -363,7 +363,7 @@ func (s *testKeyManagerSuite) TestLoadKeyEmpty(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Use default config. config := &encryption.Config{} @@ -427,7 +427,7 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { }, }, } - err = saveKeys(leadership, masterKeyMeta, keys) + err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) <-reloadEvent key, err := m.GetKey(123) @@ -453,7 +453,7 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { }, }, } - err = saveKeys(leadership, masterKeyMeta, keys) + err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) <-reloadEvent key, err = m.GetKey(123) @@ -536,7 +536,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) } @@ -579,7 +579,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with different encrption method. 
config := &encryption.Config{ @@ -613,7 +613,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) } @@ -656,7 +656,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with different encrption method. config := &encryption.Config{ @@ -691,7 +691,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) } @@ -734,7 +734,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with 100s rotation period. 
rotationPeriod, err := time.ParseDuration("100s") @@ -773,7 +773,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) } @@ -818,7 +818,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with a different master key. config := &encryption.Config{ @@ -845,7 +845,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) meta, err := config.GetMasterKeyMeta() @@ -901,7 +901,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with a different master key. 
config := &encryption.Config{ @@ -927,7 +927,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) meta, err := config.GetMasterKeyMeta() @@ -972,7 +972,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Use default config. config := &encryption.Config{} @@ -994,7 +994,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), expectedKeys), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, expectedKeys), IsTrue) } @@ -1046,7 +1046,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper) c.Assert(err, IsNil) // Config with 100s rotation period. 
rotationPeriod, err := time.ParseDuration("100s") @@ -1075,7 +1075,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) // Advance time and trigger ticker @@ -1097,7 +1097,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err = extractKeysFromKV(resp.Kvs[0]) + storedKeys, err = extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, loadedKeys), IsTrue) } @@ -1159,7 +1159,7 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with 100s rotation period. rotationPeriod, err := time.ParseDuration("100s") @@ -1188,7 +1188,7 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err := extractKeysFromKV(resp.Kvs[0]) + storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) // Invalidate leader after leader check. @@ -1202,7 +1202,7 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { // Check keys is unchanged. 
resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) c.Assert(err, IsNil) - storedKeys, err = extractKeysFromKV(resp.Kvs[0]) + storedKeys, err = extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) c.Assert(err, IsNil) c.Assert(proto.Equal(storedKeys, keys), IsTrue) } From 0321a8028a7a0d2474f67176578d1a949285f7ab Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Mon, 9 Nov 2020 13:38:22 +0800 Subject: [PATCH 36/37] fix error type Signed-off-by: Yi Wu --- pkg/encryption/kms.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/encryption/kms.go b/pkg/encryption/kms.go index 040bd4936d7..7251e933523 100644 --- a/pkg/encryption/kms.go +++ b/pkg/encryption/kms.go @@ -22,7 +22,6 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/sts" - "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/tikv/pd/pkg/errs" ) @@ -42,7 +41,7 @@ func newMasterKeyFromKMS( ciphertextKey []byte, ) (masterKey *MasterKey, err error) { if config == nil { - return nil, errors.New("missing master key KMS config") + return nil, errs.ErrEncryptionNewMasterKey.GenWithStack("missing master key KMS config") } if config.Vendor != kmsVendorAWS { return nil, errs.ErrEncryptionKMS.GenWithStack("unsupported KMS vendor: %s", config.Vendor) } From 312bb486d357ca5923a9f92a6b83bbe43a33d253 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Mon, 9 Nov 2020 14:22:01 +0800 Subject: [PATCH 37/37] fix test Signed-off-by: Yi Wu --- server/encryptionkm/key_manager_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index ce67995fb22..9425220f141 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -1046,7 +1046,7 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { }, }, } - err := saveKeys(leadership, masterKeyMeta, keys,
defaultKeyManagerHelper) + err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) c.Assert(err, IsNil) // Config with 100s rotation period. rotationPeriod, err := time.ParseDuration("100s")