From 1754e54388aa4dd3d43c7a6cc87001ce41ea36a7 Mon Sep 17 00:00:00 2001 From: Yi-Shu Tai Date: Sun, 1 Aug 2021 14:52:35 -0700 Subject: [PATCH] client: support Root CA rotation on server side --- .../pkg/transport/keepalive_listener_test.go | 2 +- client/pkg/transport/listener.go | 251 +++++++++++----- client/pkg/transport/listener_test.go | 278 ++++++++++++++++-- client/pkg/transport/listener_tls.go | 2 +- client/pkg/transport/timeout_transport.go | 2 +- .../pkg/transport/timeout_transport_test.go | 2 +- client/pkg/transport/tls.go | 2 +- client/pkg/transport/tls_test.go | 2 +- client/pkg/transport/transport.go | 2 +- client/pkg/transport/transport_test.go | 4 +- contrib/raftexample/raft.go | 2 + pkg/proxy/server.go | 9 +- pkg/proxy/server_test.go | 113 +++---- server/config/config.go | 2 +- server/embed/config.go | 15 +- server/embed/config_test.go | 4 +- server/embed/etcd.go | 10 +- server/embed/serve.go | 2 +- server/etcdmain/config.go | 2 + server/etcdmain/grpc_proxy.go | 4 +- server/etcdmain/util.go | 2 +- .../api/rafthttp/functional_test.go | 5 + server/etcdserver/api/rafthttp/transport.go | 2 +- server/etcdserver/api/rafthttp/util.go | 4 +- .../etcdserver/api/v2discovery/discovery.go | 2 +- tests/e2e/discovery_test.go | 6 +- tests/framework/integration.go | 2 +- tests/framework/integration/cluster.go | 14 +- tests/functional/agent/handler.go | 15 +- tests/integration/clientv3/metrics_test.go | 2 +- tests/integration/embed/embed_test.go | 4 +- tests/integration/lazy_cluster.go | 2 +- tests/integration/metrics_test.go | 2 +- tests/integration/root_ca_rotation_test.go | 233 +++++++++++++++ tests/integration/util_test.go | 10 +- tests/integration/v3_grpc_test.go | 28 +- tests/integration/v3_tls_test.go | 6 +- tools/etcd-dump-metrics/metrics.go | 2 +- 38 files changed, 817 insertions(+), 234 deletions(-) create mode 100644 tests/integration/root_ca_rotation_test.go diff --git a/client/pkg/transport/keepalive_listener_test.go b/client/pkg/transport/keepalive_listener_test.go index efe312d94a88..49bedea5ec1a 100644 --- a/client/pkg/transport/keepalive_listener_test.go +++ b/client/pkg/transport/keepalive_listener_test.go @@ -58,7 +58,7 @@ func TestNewKeepAliveListener(t *testing.T) { } tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile} tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - tlscfg, err := tlsInfo.ServerConfig() + tlscfg, err := tlsInfo.ReloadableServerConfig() if err != nil { t.Fatalf("unexpected serverConfig error: %v", err) } diff --git a/client/pkg/transport/listener.go b/client/pkg/transport/listener.go index cbe3b3f891a2..524c63f01b39 100644 --- a/client/pkg/transport/listener.go +++ b/client/pkg/transport/listener.go @@ -30,6 +30,8 @@ import ( "os" "path/filepath" "strings" + "sync" + "sync/atomic" "time" "go.etcd.io/etcd/client/pkg/v3/fileutil" @@ -38,6 +40,10 @@ import ( "go.uber.org/zap" ) +const ( + defaultRootCAReloadDuration = 5 * time.Minute +) + // NewListener creates a new listner. func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) { return newListener(addr, scheme, WithTLSInfo(tlsinfo)) @@ -185,17 +191,51 @@ type TLSInfo struct { // EmptyCN indicates that the cert must have empty CN. // If true, ClientConfig() will return an error for a cert with non empty CN. EmptyCN bool + + tlsConfig atomic.Value // *tls.Config + refreshOnce sync.Once + RefreshDuration time.Duration + EnableRootCAReload bool + refreshDone chan struct{} } -func (info TLSInfo) String() string { +func (info *TLSInfo) String() string { return fmt.Sprintf("cert = %s, key = %s, client-cert=%s, client-key=%s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.ClientCertFile, info.ClientKeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile) } -func (info TLSInfo) Empty() bool { +func (info *TLSInfo) Empty() bool { return info.CertFile == "" && info.KeyFile == "" } -func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) { +func (info *TLSInfo) Clone() *TLSInfo { + return &TLSInfo{ + CertFile: info.CertFile, + KeyFile: info.KeyFile, + ClientCertFile: info.ClientCertFile, + ClientKeyFile: info.ClientKeyFile, + TrustedCAFile: info.TrustedCAFile, + ClientCertAuth: info.ClientCertAuth, + CRLFile: info.CRLFile, + InsecureSkipVerify: info.InsecureSkipVerify, + SkipClientSANVerify: info.SkipClientSANVerify, + ServerName: info.ServerName, + HandshakeFailure: info.HandshakeFailure, + CipherSuites: info.CipherSuites, + selfCert: info.selfCert, + parseFunc: info.parseFunc, + AllowedCN: info.AllowedCN, + AllowedHostname: info.AllowedHostname, + Logger: info.Logger, + EmptyCN: info.EmptyCN, + RefreshDuration: info.RefreshDuration, + EnableRootCAReload: info.EnableRootCAReload, + } +} + +func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info *TLSInfo, err error) { + if info == nil { + info = &TLSInfo{} + } info.Logger = lg if selfSignedCertValidity == 0 { err = fmt.Errorf("selfSignedCertValidity is invalid,it should be greater than 0") @@ -334,6 +374,87 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali return SelfCert(lg, dirpath, hosts, selfSignedCertValidity) } +func (info *TLSInfo) startRefresh() { + info.refreshOnce.Do( + func() { + info.loadServerTlsConfig() + if info.EnableRootCAReload { + if info.RefreshDuration == 0 { + info.RefreshDuration = defaultRootCAReloadDuration + } + info.refreshDone = make(chan struct{}) + go info.tlsConfigRefreshLoop() + } + }, + ) +} + +func (info *TLSInfo) loadServerTlsConfig() { + if info.Logger != nil { + info.Logger.Info("tls config reload from files") + } + cfg, err := info.serverConfig() + if err == nil { + info.tlsConfig.Store(cfg) + } else { + if info.Logger != nil { + info.Logger.Error("reload tls config error:", zap.Error(err)) + } + } +} + +func (info *TLSInfo) tlsConfigRefreshLoop() { + ticker := time.NewTicker(info.RefreshDuration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + info.loadServerTlsConfig() + case <-info.refreshDone: + return + } + } +} + +func (info *TLSInfo) getClientCertificate() (*tls.Certificate, error) { + certFile, keyFile := info.CertFile, info.KeyFile + if info.ClientCertFile != "" { + certFile, keyFile = info.ClientCertFile, info.ClientKeyFile + } + return info.getCertificates(certFile, keyFile) +} + +func (info *TLSInfo) getServerCertificates() (*tls.Certificate, error) { + return info.getCertificates(info.CertFile, info.KeyFile) +} + +func (info *TLSInfo) getCertificates(certFile, keyFile string) (*tls.Certificate, error) { + cert, err := tlsutil.NewCert(certFile, keyFile, info.parseFunc) + if os.IsNotExist(err) { + if info.Logger != nil { + info.Logger.Warn( + "failed to find cert files", + zap.String("cert-file", certFile), + zap.String("key-file", keyFile), + zap.Error(err), + ) + } + } else if err != nil { + if info.Logger != nil { + info.Logger.Warn( + "failed to create peer certificate", + zap.String("cert-file", certFile), + zap.String("key-file", keyFile), + zap.Error(err), + ) + } + } + if err != nil { + return nil, err + } + return cert, err +} + // baseConfig is called on initial TLS handshake start. // // Previously, @@ -354,7 +475,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali // handshake, in order to trigger (*tls.Config).GetCertificate and populate // rest of the certificates on every new TLS connection, even when client // SNI is empty (e.g. cert only includes IPs). -func (info TLSInfo) baseConfig() (*tls.Config, error) { +func (info *TLSInfo) baseConfig() (*tls.Config, error) { if info.KeyFile == "" || info.CertFile == "" { return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) } @@ -415,82 +536,15 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) { return errors.New("client certificate authentication failed") } } - - // this only reloads certs when there's a client request - // TODO: support server-side refresh (e.g. inotify, SIGHUP), caching - cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) { - cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc) - if os.IsNotExist(err) { - if info.Logger != nil { - info.Logger.Warn( - "failed to find peer cert files", - zap.String("cert-file", info.CertFile), - zap.String("key-file", info.KeyFile), - zap.Error(err), - ) - } - } else if err != nil { - if info.Logger != nil { - info.Logger.Warn( - "failed to create peer certificate", - zap.String("cert-file", info.CertFile), - zap.String("key-file", info.KeyFile), - zap.Error(err), - ) - } - } - return cert, err - } - cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) { - certfile, keyfile := info.CertFile, info.KeyFile - if info.ClientCertFile != "" { - certfile, keyfile = info.ClientCertFile, info.ClientKeyFile - } - cert, err = tlsutil.NewCert(certfile, keyfile, info.parseFunc) - if os.IsNotExist(err) { - if info.Logger != nil { - info.Logger.Warn( - "failed to find client cert files", - zap.String("cert-file", certfile), - zap.String("key-file", keyfile), - zap.Error(err), - ) - } - } else if err != nil { - if info.Logger != nil { - info.Logger.Warn( - "failed to create client certificate", - zap.String("cert-file", certfile), - zap.String("key-file", keyfile), - zap.Error(err), - ) - } - } - return cert, err - } return cfg, nil } -// cafiles returns a list of CA file paths. -func (info TLSInfo) cafiles() []string { - cs := make([]string, 0) - if info.TrustedCAFile != "" { - cs = append(cs, info.TrustedCAFile) - } - return cs -} - -// ServerConfig generates a tls.Config object for use by an HTTP server. -func (info TLSInfo) ServerConfig() (*tls.Config, error) { +func (info *TLSInfo) serverConfig() (*tls.Config, error) { cfg, err := info.baseConfig() if err != nil { return nil, err } - if info.Logger == nil { - info.Logger = zap.NewNop() - } - cfg.ClientAuth = tls.NoClientCert if info.TrustedCAFile != "" || info.ClientCertAuth { cfg.ClientAuth = tls.RequireAndVerifyClientCert @@ -515,11 +569,44 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) { // setting Max TLS version to TLS 1.2 for go 1.13 cfg.MaxVersion = tls.VersionTLS12 + certs, err := info.getServerCertificates() + if err != nil { + return nil, err + } + cfg.Certificates = []tls.Certificate{*certs} return cfg, nil } +// cafiles returns a list of CA file paths. +func (info *TLSInfo) cafiles() []string { + cs := make([]string, 0) + if info.TrustedCAFile != "" { + cs = append(cs, info.TrustedCAFile) + } + return cs +} + +// ReloadableServerConfig generates a tls.Config object for use by an HTTP server. +func (info *TLSInfo) ReloadableServerConfig() (*tls.Config, error) { + info.startRefresh() + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + cfg, ok := info.tlsConfig.Load().(*tls.Config) + if !ok { + return nil, errors.New("server tls configuration not ready") + } + return cfg.Clone(), nil + }, + // Needed to tell go http server to serve http2 + NextProtos: []string{"h2"}, + GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return info.getClientCertificate() + }, + }, nil +} + // ClientConfig generates a tls.Config object for use by an HTTP client. -func (info TLSInfo) ClientConfig() (*tls.Config, error) { +func (info *TLSInfo) ClientConfig() (*tls.Config, error) { var cfg *tls.Config var err error @@ -574,9 +661,27 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) { // setting Max TLS version to TLS 1.2 for go 1.13 cfg.MaxVersion = tls.VersionTLS12 + cert, err := info.getClientCertificate() + if err != nil { + if info.Logger != nil { + info.Logger.Warn( + "cannot create client certificate", + zap.Error(err), + ) + } + } else { + cfg.Certificates = []tls.Certificate{*cert} + } + return cfg, nil } +func (info *TLSInfo) Close() { + if info.refreshDone != nil { + close(info.refreshDone) + } +} + // IsClosedConnError returns true if the error is from closing listener, cmux. // copied from golang.org/x/net/http2/http2.go func IsClosedConnError(err error) bool { diff --git a/client/pkg/transport/listener_test.go b/client/pkg/transport/listener_test.go index 11e2182fe387..07e0e9adb041 100644 --- a/client/pkg/transport/listener_test.go +++ b/client/pkg/transport/listener_test.go @@ -15,19 +15,30 @@ package transport import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" + "fmt" + "io/ioutil" + "math/big" "net" "net/http" "os" + "path" "testing" "time" + "go.uber.org/zap" "go.uber.org/zap/zaptest" ) -func createSelfCert(t *testing.T, hosts ...string) (*TLSInfo, error) { +func createSelfCert(t *testing.T) (*TLSInfo, error) { return createSelfCertEx(t, "127.0.0.1") } @@ -37,7 +48,7 @@ func createSelfCertEx(t *testing.T, host string, additionalUsages ...x509.ExtKey if err != nil { return nil, err } - return &info, nil + return info, nil } func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { @@ -53,7 +64,7 @@ func TestNewListenerTLSInfo(t *testing.T) { if err != nil { t.Fatalf("unable to create cert: %v", err) } - testNewListenerTLSInfoAccept(t, *tlsInfo) + testNewListenerTLSInfoAccept(t, tlsInfo) } func TestNewListenerWithOpts(t *testing.T) { @@ -218,8 +229,8 @@ func TestNewListenerWithSocketOpts(t *testing.T) { } } -func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { - ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo) +func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo *TLSInfo) { + ln, err := NewListener("127.0.0.1:0", "https", tlsInfo) if err != nil { t.Fatalf("unexpected NewListener error: %v", err) } @@ -243,6 +254,7 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { // with specified address in its certificate the connection is still accepted // if the flag SkipClientSANVerify is set (i.e. checkSAN() is disabled for the client side) func TestNewListenerTLSInfoSkipClientSANVerify(t *testing.T) { + tests := []struct { skipClientSANVerify bool goodClientHost bool @@ -349,7 +361,7 @@ func TestNewTransportTLSInfo(t *testing.T) { t.Fatalf("unable to create cert: %v", err) } - tests := []TLSInfo{ + tests := []*TLSInfo{ {}, { CertFile: tlsinfo.CertFile, @@ -380,7 +392,7 @@ func TestNewTransportTLSInfo(t *testing.T) { func TestTLSInfoNonexist(t *testing.T) { tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"} - _, err := tlsInfo.ServerConfig() + _, err := tlsInfo.serverConfig() werr := &os.PathError{ Op: "open", Path: "@badname", @@ -393,17 +405,17 @@ func TestTLSInfoNonexist(t *testing.T) { func TestTLSInfoEmpty(t *testing.T) { tests := []struct { - info TLSInfo + info *TLSInfo want bool }{ - {TLSInfo{}, true}, - {TLSInfo{TrustedCAFile: "baz"}, true}, - {TLSInfo{CertFile: "foo"}, false}, - {TLSInfo{KeyFile: "bar"}, false}, - {TLSInfo{CertFile: "foo", KeyFile: "bar"}, false}, - {TLSInfo{CertFile: "foo", TrustedCAFile: "baz"}, false}, - {TLSInfo{KeyFile: "bar", TrustedCAFile: "baz"}, false}, - {TLSInfo{CertFile: "foo", KeyFile: "bar", TrustedCAFile: "baz"}, false}, + {&TLSInfo{}, true}, + {&TLSInfo{TrustedCAFile: "baz"}, true}, + {&TLSInfo{CertFile: "foo"}, false}, + {&TLSInfo{KeyFile: "bar"}, false}, + {&TLSInfo{CertFile: "foo", KeyFile: "bar"}, false}, + {&TLSInfo{CertFile: "foo", TrustedCAFile: "baz"}, false}, + {&TLSInfo{KeyFile: "bar", TrustedCAFile: "baz"}, false}, + {&TLSInfo{CertFile: "foo", KeyFile: "bar", TrustedCAFile: "baz"}, false}, } for i, tt := range tests { @@ -420,7 +432,7 @@ func TestTLSInfoMissingFields(t *testing.T) { t.Fatalf("unable to create cert: %v", err) } - tests := []TLSInfo{ + tests := []*TLSInfo{ {CertFile: tlsinfo.CertFile}, {KeyFile: tlsinfo.KeyFile}, {CertFile: tlsinfo.CertFile, TrustedCAFile: tlsinfo.TrustedCAFile}, @@ -428,8 +440,10 @@ func TestTLSInfoMissingFields(t *testing.T) { } for i, info := range tests { - if _, err = info.ServerConfig(); err == nil { - t.Errorf("#%d: expected non-nil error from ServerConfig()", i) + _, err = info.serverConfig() + + if err == nil { + t.Errorf("#%d: expected non nil error from serverConfig()", i) } if _, err = info.ClientConfig(); err == nil { @@ -445,22 +459,22 @@ func TestTLSInfoParseFuncError(t *testing.T) { } tests := []struct { - info TLSInfo + info *TLSInfo }{ { - info: *tlsinfo, + info: tlsinfo, }, { - info: TLSInfo{CertFile: "", KeyFile: "", TrustedCAFile: tlsinfo.CertFile, EmptyCN: true}, + info: &TLSInfo{CertFile: "", KeyFile: "", TrustedCAFile: tlsinfo.CertFile, EmptyCN: true}, }, } for i, tt := range tests { tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) - if _, err = tt.info.ServerConfig(); err == nil { - t.Errorf("#%d: expected non-nil error from ServerConfig()", i) + if _, err = tt.info.serverConfig(); err == nil { + t.Errorf("#%d: expected non-nil error from ReloadableServerConfig()", i) } if _, err = tt.info.ClientConfig(); err == nil { @@ -477,18 +491,18 @@ func TestTLSInfoConfigFuncs(t *testing.T) { } tests := []struct { - info TLSInfo + info *TLSInfo clientAuth tls.ClientAuthType wantCAs bool }{ { - info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, Logger: ln}, + info: &TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, Logger: ln}, clientAuth: tls.NoClientCert, wantCAs: false, }, { - info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, TrustedCAFile: tlsinfo.CertFile, Logger: ln}, + info: &TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, TrustedCAFile: tlsinfo.CertFile, Logger: ln}, clientAuth: tls.RequireAndVerifyClientCert, wantCAs: true, }, @@ -497,9 +511,9 @@ func TestTLSInfoConfigFuncs(t *testing.T) { for i, tt := range tests { tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - sCfg, err := tt.info.ServerConfig() + sCfg, err := tt.info.serverConfig() if err != nil { - t.Errorf("#%d: expected nil error from ServerConfig(), got non-nil: %v", i, err) + t.Errorf("#%d: expected nil error from ReloadableServerConfig(), got non-nil: %v", i, err) } if tt.wantCAs != (sCfg.ClientCAs != nil) { @@ -567,3 +581,209 @@ func TestSocktOptsEmpty(t *testing.T) { } } } + +func newSerialNumber(t *testing.T) *big.Int { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fail() + } + return serialNumber +} + +func createRootCertificateAuthority(rootCaPath string, oldPem []byte, t *testing.T) (*x509.Certificate, []byte, *ecdsa.PrivateKey) { + serialNumber := newSerialNumber(t) + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"etcd"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * (24 * time.Hour)), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment, + BasicConstraintsValid: true, + IsCA: true, + } + + caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + caBlocks := [][]byte{caBytes} + if len(oldPem) > 0 { + caBlocks = append(caBlocks, oldPem) + } + marshalCerts(caBlocks, rootCaPath, t) + return ca, caBytes, priv +} + +func generateCerts(privKey *ecdsa.PrivateKey, rootCA *x509.Certificate, dir, suffix string, t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + serialNumber := newSerialNumber(t) + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"etcd"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * (24 * time.Hour)), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, rootCA, &priv.PublicKey, privKey) + if err != nil { + t.Fatal(err) + } + marshalCerts([][]byte{caBytes}, path.Join(dir, fmt.Sprintf("cert%s.pem", suffix)), t) + marshalKeys(priv, path.Join(dir, fmt.Sprintf("key%s.pem", suffix)), t) +} + +func marshalCerts(caBytes [][]byte, certPath string, t *testing.T) { + var caPerm bytes.Buffer + for _, caBlock := range caBytes { + err := pem.Encode(&caPerm, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBlock, + }) + if err != nil { + t.Fatal(err) + } + } + ioutil.WriteFile(certPath, caPerm.Bytes(), 0644) +} + +func marshalKeys(privKey *ecdsa.PrivateKey, keyPath string, t *testing.T) { + privBytes, err := x509.MarshalECPrivateKey(privKey) + if err != nil { + t.Fatal(err) + } + + var keyPerm bytes.Buffer + err = pem.Encode(&keyPerm, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: privBytes, + }) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(keyPath, keyPerm.Bytes(), 0644) +} + +func TestRootCAReload(t *testing.T) { + tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir-reload") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + rootCAPath := path.Join(tmpdir, "ca-cert.pem") + logger := zap.NewExample() + rootCA, caBytes, privKey := createRootCertificateAuthority(rootCAPath, []byte{}, t) + generateCerts(privKey, rootCA, tmpdir, "_old", t) + tlsinfo := &TLSInfo{ + TrustedCAFile: rootCAPath, + CertFile: path.Join(tmpdir, "cert_old.pem"), + KeyFile: path.Join(tmpdir, "key_old.pem"), + ClientCertFile: path.Join(tmpdir, "cert_old.pem"), + ClientKeyFile: path.Join(tmpdir, "key_old.pem"), + Logger: logger, + RefreshDuration: 100 * time.Millisecond, + EnableRootCAReload: true, + } + + ln, err := NewListener("127.0.0.1:0", "https", tlsinfo) + if err != nil { + t.Fatalf("unexpected NewListener error: %v", err) + } + defer ln.Close() + cfg, err := tlsinfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + tr := &http.Transport{TLSClientConfig: cfg} + cli := &http.Client{Transport: tr} + go func() { + cli.Get("https://" + ln.Addr().String()) + }() + + errChan := make(chan error) + go func() { + conn, err := ln.Accept() + if err != nil { + errChan <- err + } + if _, ok := conn.(*tls.Conn); !ok { + errChan <- errors.New("failed to accept *tls.Conn") + } + conn.Close() + errChan <- nil + }() + + select { + case <-time.After(10 * time.Second): + t.Fatalf("timeout accept") + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + } + + // regenerate rootCA and sign new certs + rootCA, _, privKey = createRootCertificateAuthority(rootCAPath, caBytes, t) + generateCerts(privKey, rootCA, tmpdir, "_new", t) + + // give server some time to reload new CA + time.Sleep(time.Second) + + newTlsinfo := &TLSInfo{ + TrustedCAFile: rootCAPath, + CertFile: path.Join(tmpdir, "cert_new.pem"), + KeyFile: path.Join(tmpdir, "key_new.pem"), + ClientCertFile: path.Join(tmpdir, "cert_new.pem"), + ClientKeyFile: path.Join(tmpdir, "key_new.pem"), + Logger: logger, + } + + cfg, err = newTlsinfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + + tr = &http.Transport{TLSClientConfig: cfg} + cli = &http.Client{Transport: tr} + go func() { + cli.Get("https://" + ln.Addr().String()) + }() + + go func() { + conn, err := ln.Accept() + if err != nil { + errChan <- err + } + if _, ok := conn.(*tls.Conn); !ok { + errChan <- errors.New("failed to accept *tls.Conn") + } + conn.Close() + errChan <- nil + }() + + select { + case <-time.After(10 * time.Second): + t.Fatalf("timeout accept") + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + } +} diff --git a/client/pkg/transport/listener_tls.go b/client/pkg/transport/listener_tls.go index 37b17ec275ea..c5acb850f3cb 100644 --- a/client/pkg/transport/listener_tls.go +++ b/client/pkg/transport/listener_tls.go @@ -50,7 +50,7 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.L l.Close() return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String()) } - tlscfg, err := tlsinfo.ServerConfig() + tlscfg, err := tlsinfo.ReloadableServerConfig() if err != nil { return nil, err } diff --git a/client/pkg/transport/timeout_transport.go b/client/pkg/transport/timeout_transport.go index ea16b4c0f869..ec4936bd9618 100644 --- a/client/pkg/transport/timeout_transport.go +++ b/client/pkg/transport/timeout_transport.go @@ -24,7 +24,7 @@ import ( // If read/write on the created connection blocks longer than its time limit, // it will return timeout error. // If read/write timeout is set, transport will not be able to reuse connection. -func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) { +func NewTimeoutTransport(info *TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) { tr, err := NewTransport(info, dialtimeoutd) if err != nil { return nil, err diff --git a/client/pkg/transport/timeout_transport_test.go b/client/pkg/transport/timeout_transport_test.go index 95079f9b5982..ae3d9e62e897 100644 --- a/client/pkg/transport/timeout_transport_test.go +++ b/client/pkg/transport/timeout_transport_test.go @@ -26,7 +26,7 @@ import ( // TestNewTimeoutTransport tests that NewTimeoutTransport returns a transport // that can dial out timeout connections. func TestNewTimeoutTransport(t *testing.T) { - tr, err := NewTimeoutTransport(TLSInfo{}, time.Hour, time.Hour, time.Hour) + tr, err := NewTimeoutTransport(&TLSInfo{}, time.Hour, time.Hour, time.Hour) if err != nil { t.Fatalf("unexpected NewTimeoutTransport error: %v", err) } diff --git a/client/pkg/transport/tls.go b/client/pkg/transport/tls.go index 8c3a35b140bb..697cd148cd62 100644 --- a/client/pkg/transport/tls.go +++ b/client/pkg/transport/tls.go @@ -23,7 +23,7 @@ import ( // ValidateSecureEndpoints scans the given endpoints against tls info, returning only those // endpoints that could be validated as secure. -func ValidateSecureEndpoints(tlsInfo TLSInfo, eps []string) ([]string, error) { +func ValidateSecureEndpoints(tlsInfo *TLSInfo, eps []string) ([]string, error) { t, err := NewTransport(tlsInfo, 5*time.Second) if err != nil { return nil, err diff --git a/client/pkg/transport/tls_test.go b/client/pkg/transport/tls_test.go index 46af1db6786c..47b1abf3a9cb 100644 --- a/client/pkg/transport/tls_test.go +++ b/client/pkg/transport/tls_test.go @@ -76,7 +76,7 @@ func TestValidateSecureEndpoints(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - secureEps, err := ValidateSecureEndpoints(*tlsInfo, test.endPoints) + secureEps, err := ValidateSecureEndpoints(tlsInfo, test.endPoints) if test.expectedErr != (err != nil) { t.Errorf("Unexpected error, got: %v, want: %v", err, test.expectedErr) } diff --git a/client/pkg/transport/transport.go b/client/pkg/transport/transport.go index 91462dcdb08b..d452051066d5 100644 --- a/client/pkg/transport/transport.go +++ b/client/pkg/transport/transport.go @@ -24,7 +24,7 @@ import ( type unixTransport struct{ *http.Transport } -func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { +func NewTransport(info *TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { cfg, err := info.ClientConfig() if err != nil { return nil, err diff --git a/client/pkg/transport/transport_test.go b/client/pkg/transport/transport_test.go index 315f32cf2dca..60fd11494126 100644 --- a/client/pkg/transport/transport_test.go +++ b/client/pkg/transport/transport_test.go @@ -40,10 +40,10 @@ func TestNewTransportTLSInvalidCipherSuitesTLS12(t *testing.T) { } // make server and client have unmatched cipher suites - srvTLS, cliTLS := *tlsInfo, *tlsInfo + srvTLS, cliTLS := tlsInfo.Clone(), tlsInfo.Clone() srvTLS.CipherSuites, cliTLS.CipherSuites = cipherSuites[:2], cipherSuites[2:] - ln, err := NewListener("127.0.0.1:0", "https", &srvTLS) + ln, err := NewListener("127.0.0.1:0", "https", srvTLS) if err != nil { t.Fatalf("unexpected NewListener error: %v", err) } diff --git a/contrib/raftexample/raft.go b/contrib/raftexample/raft.go index b1618e1c1b70..a90172e984ab 100644 --- a/contrib/raftexample/raft.go +++ b/contrib/raftexample/raft.go @@ -17,6 +17,7 @@ package main import ( "context" "fmt" + "go.etcd.io/etcd/client/pkg/v3/transport" "log" "net/http" "net/url" @@ -313,6 +314,7 @@ func (rc *raftNode) startRaft() { ServerStats: stats.NewServerStats("", ""), LeaderStats: stats.NewLeaderStats(zap.NewExample(), strconv.Itoa(rc.id)), ErrorC: make(chan error), + TLSInfo: &transport.TLSInfo{}, } rc.transport.Start() diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 72a0c7483d23..2cdf93ad02d3 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -133,7 +133,7 @@ type ServerConfig struct { Logger *zap.Logger From url.URL To url.URL - TLSInfo transport.TLSInfo + TLSInfo *transport.TLSInfo DialTimeout time.Duration BufferSize int RetryInterval time.Duration @@ -147,7 +147,7 @@ type server struct { to url.URL toPort int - tlsInfo transport.TLSInfo + tlsInfo *transport.TLSInfo dialTimeout time.Duration bufferSize int @@ -250,7 +250,7 @@ func NewServer(cfg ServerConfig) Server { var ln net.Listener if !s.tlsInfo.Empty() { - ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo) + ln, err = transport.NewListener(addr, s.from.Scheme, s.tlsInfo) } else { ln, err = net.Listen(s.from.Scheme, addr) } @@ -639,6 +639,7 @@ func (s *server) Close() (err error) { } s.lg.Sync() s.listenerMu.Unlock() + s.tlsInfo.Close() }) s.closeWg.Wait() return err @@ -973,7 +974,7 @@ func (s *server) ResetListener() error { var ln net.Listener var err error if !s.tlsInfo.Empty() { - ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo) + ln, err = transport.NewListener(s.from.Host, s.from.Scheme, s.tlsInfo) } else { ln, err = net.Listen(s.from.Scheme, s.from.Host) } diff --git a/pkg/proxy/server_test.go b/pkg/proxy/server_test.go index 999c4304a5d7..5a7a32d74b53 100644 --- a/pkg/proxy/server_test.go +++ b/pkg/proxy/server_test.go @@ -50,7 +50,7 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { lg := zaptest.NewLogger(t) srcAddr, dstAddr := newUnixAddr(), newUnixAddr() if scheme == "tcp" { - ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{}) + ln1, ln2 := listen(t, "tcp", "localhost:0", &transport.TLSInfo{}), listen(t, "tcp", "localhost:0", &transport.TLSInfo{}) srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String() ln1.Close() ln2.Close() @@ -65,9 +65,10 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { defer ln.Close() cfg := ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, } if secure { cfg.TLSInfo = tlsInfo @@ -162,9 +163,9 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { } } -func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo { +func createTLSInfo(lg *zap.Logger, secure bool) *transport.TLSInfo { if secure { - return transport.TLSInfo{ + return &transport.TLSInfo{ KeyFile: "../../tests/fixtures/server.key.insecure", CertFile: "../../tests/fixtures/server.crt", TrustedCAFile: "../../tests/fixtures/ca.crt", @@ -172,7 +173,7 @@ func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo { Logger: lg, } } - return transport.TLSInfo{Logger: lg} + return &transport.TLSInfo{Logger: lg} } func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) } @@ -190,9 +191,10 @@ func testServerDelayAccept(t *testing.T, secure bool) { defer ln.Close() cfg := ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, } if secure { cfg.TLSInfo = tlsInfo @@ -242,13 +244,14 @@ func TestServer_PauseTx(t *testing.T) { os.RemoveAll(srcAddr) os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -258,7 +261,7 @@ func TestServer_PauseTx(t *testing.T) { p.PauseTx() data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) recvc := make(chan []byte, 1) go func() { @@ -291,13 +294,14 @@ func TestServer_ModifyTx_corrupt(t *testing.T) { os.RemoveAll(srcAddr) os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -309,13 +313,13 @@ func TestServer_ModifyTx_corrupt(t *testing.T) { return d }) data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); bytes.Equal(d, data) { t.Fatalf("expected corrupted data, got %q", string(d)) } p.UnmodifyTx() - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); !bytes.Equal(d, data) { t.Fatalf("expected uncorrupted data, got %q", string(d)) } @@ -329,13 +333,14 @@ func TestServer_ModifyTx_packet_loss(t *testing.T) { os.RemoveAll(srcAddr) os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -348,13 +353,13 @@ func TestServer_ModifyTx_packet_loss(t *testing.T) { return d[:half:half] }) data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); bytes.Equal(d, data) { t.Fatalf("expected corrupted data, got %q", string(d)) } p.UnmodifyTx() - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); !bytes.Equal(d, data) { t.Fatalf("expected uncorrupted data, got %q", string(d)) } @@ -368,13 +373,14 @@ func TestServer_BlackholeTx(t *testing.T) { os.RemoveAll(srcAddr) os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -384,7 +390,7 @@ func TestServer_BlackholeTx(t *testing.T) { p.BlackholeTx() data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) recvc := make(chan []byte, 1) go func() { @@ -401,7 +407,7 @@ func TestServer_BlackholeTx(t *testing.T) { // expect different data, old data dropped data[0]++ - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) select { case d := <-recvc: @@ -421,13 +427,14 @@ func TestServer_Shutdown(t *testing.T) { os.RemoveAll(srcAddr) os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -439,7 +446,7 @@ func TestServer_Shutdown(t *testing.T) { time.Sleep(200 * time.Millisecond) data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); !bytes.Equal(d, data) { t.Fatalf("expected %q, got %q", string(data), string(d)) } @@ -454,13 +461,14 @@ func TestServer_ShutdownListener(t *testing.T) { os.RemoveAll(dstAddr) }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln := listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, }) waitForServer(t, p) @@ -471,11 +479,11 @@ func TestServer_ShutdownListener(t *testing.T) { ln.Close() time.Sleep(200 * time.Millisecond) - ln = listen(t, scheme, dstAddr, transport.TLSInfo{}) + ln = listen(t, scheme, dstAddr, &transport.TLSInfo{}) defer ln.Close() data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) + send(t, data, scheme, srcAddr, &transport.TLSInfo{}) if d := receive(t, ln); !bytes.Equal(d, data) { t.Fatalf("expected %q, got %q", string(data), string(d)) } @@ -488,7 +496,7 @@ func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, fal func testServerHTTP(t *testing.T, secure, delayTx bool) { lg := zaptest.NewLogger(t) scheme := "tcp" - ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{}) + ln1, ln2 := listen(t, scheme, "localhost:0", &transport.TLSInfo{}), listen(t, scheme, "localhost:0", &transport.TLSInfo{}) srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String() ln1.Close() ln2.Close() @@ -507,7 +515,7 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { tlsInfo := createTLSInfo(lg, secure) var tlsConfig *tls.Config if secure { - _, err := tlsInfo.ServerConfig() + _, err := tlsInfo.ReloadableServerConfig() if err != nil { t.Fatal(err) } @@ -535,9 +543,10 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { time.Sleep(200 * time.Millisecond) cfg := ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: lg, + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + TLSInfo: &transport.TLSInfo{}, } if secure { cfg.TLSInfo = tlsInfo @@ -634,10 +643,10 @@ func newUnixAddr() string { return addr } -func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) { +func listen(t *testing.T, scheme, addr string, tlsInfo *transport.TLSInfo) (ln net.Listener) { var err error if !tlsInfo.Empty() { - ln, err = transport.NewListener(addr, scheme, &tlsInfo) + ln, err = transport.NewListener(addr, scheme, tlsInfo) } else { ln, err = net.Listen(scheme, addr) } @@ -647,7 +656,7 @@ func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln ne return ln } -func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) { +func send(t *testing.T, data []byte, scheme, addr string, tlsInfo *transport.TLSInfo) { var out net.Conn var err error if !tlsInfo.Empty() { diff --git a/server/config/config.go b/server/config/config.go index 5206b3dc5f98..74685c918503 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -72,7 +72,7 @@ type ServerConfig struct { InitialPeerURLsMap types.URLsMap InitialClusterToken string NewCluster bool - PeerTLSInfo transport.TLSInfo + PeerTLSInfo *transport.TLSInfo CORS map[string]struct{} diff --git a/server/embed/config.go b/server/embed/config.go index af4e2524182a..54c4bb2fc980 100644 --- a/server/embed/config.go +++ b/server/embed/config.go @@ -213,9 +213,9 @@ type Config struct { LPUrls, LCUrls []url.URL APUrls, ACUrls []url.URL - ClientTLSInfo transport.TLSInfo + ClientTLSInfo *transport.TLSInfo ClientAutoTLS bool - PeerTLSInfo transport.TLSInfo + PeerTLSInfo *transport.TLSInfo PeerAutoTLS bool // SelfSignedCertValidity specifies the validity period of the client and peer certificates // that are automatically generated by etcd when you specify ClientAutoTLS and PeerAutoTLS, @@ -528,7 +528,6 @@ func NewConfig() *Config { ExperimentalCompactHashCheckTime: time.Minute, V2Deprecation: config.V2_DEPR_DEFAULT, - DiscoveryCfg: v3discovery.DiscoveryConfig{ ConfigSpec: clientv3.ConfigSpec{ DialTimeout: DefaultDiscoveryDialTimeout, @@ -540,6 +539,8 @@ func NewConfig() *Config { Auth: &clientv3.AuthConfig{}, }, }, + PeerTLSInfo: &transport.TLSInfo{}, + ClientTLSInfo: &transport.TLSInfo{}, } cfg.InitialCluster = cfg.InitialClusterFromName(cfg.Name) return cfg @@ -637,8 +638,8 @@ func (cfg *configYAML) configFromFile(path string) error { tls.ClientCertAuth = ysc.CertAuth tls.TrustedCAFile = ysc.TrustedCAFile } - copySecurityDetails(&cfg.ClientTLSInfo, &cfg.ClientSecurityJSON) - copySecurityDetails(&cfg.PeerTLSInfo, &cfg.PeerSecurityJSON) + copySecurityDetails(cfg.ClientTLSInfo, &cfg.ClientSecurityJSON) + copySecurityDetails(cfg.PeerTLSInfo, &cfg.PeerSecurityJSON) cfg.ClientAutoTLS = cfg.ClientSecurityJSON.AutoTLS cfg.PeerAutoTLS = cfg.PeerSecurityJSON.AutoTLS if cfg.SelfSignedCertValidity == 0 { @@ -922,7 +923,7 @@ func (cfg *Config) ClientSelfCert() (err error) { if err != nil { return err } - return updateCipherSuites(&cfg.ClientTLSInfo, cfg.CipherSuites) + return updateCipherSuites(cfg.ClientTLSInfo, cfg.CipherSuites) } func (cfg *Config) PeerSelfCert() (err error) { @@ -941,7 +942,7 @@ func (cfg *Config) PeerSelfCert() (err error) { if err != nil { return err } - return updateCipherSuites(&cfg.PeerTLSInfo, cfg.CipherSuites) + return updateCipherSuites(cfg.PeerTLSInfo, cfg.CipherSuites) } // UpdateDefaultClusterFromName updates cluster advertise URLs with, if available, default host, diff --git a/server/embed/config_test.go b/server/embed/config_test.go index bb5c088805fd..ef6f80c68bd3 100644 --- a/server/embed/config_test.go +++ b/server/embed/config_test.go @@ -72,10 +72,10 @@ func TestConfigFileOtherFields(t *testing.T) { t.Fatal(err) } - if !ctls.equals(&cfg.ClientTLSInfo) { + if !ctls.equals(cfg.ClientTLSInfo) { t.Errorf("ClientTLS = %v, want %v", cfg.ClientTLSInfo, ctls) } - if !ptls.equals(&cfg.PeerTLSInfo) { + if !ptls.equals(cfg.PeerTLSInfo) { t.Errorf("PeerTLS = %v, want %v", cfg.PeerTLSInfo, ptls) } diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 8a9ed897ff10..7374db77084a 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -492,7 +492,7 @@ func (e *Etcd) Err() <-chan error { } func configurePeerListeners(cfg *Config) (peers []*peerListener, err error) { - if err = updateCipherSuites(&cfg.PeerTLSInfo, cfg.CipherSuites); err != nil { + if err = updateCipherSuites(cfg.PeerTLSInfo, cfg.CipherSuites); err != nil { return nil, err } if err = cfg.PeerSelfCert(); err != nil { @@ -536,7 +536,7 @@ func configurePeerListeners(cfg *Config) (peers []*peerListener, err error) { } peers[i] = &peerListener{close: func(context.Context) error { return nil }} peers[i].Listener, err = transport.NewListenerWithOpts(u.Host, u.Scheme, - transport.WithTLSInfo(&cfg.PeerTLSInfo), + transport.WithTLSInfo(cfg.PeerTLSInfo), transport.WithSocketOpts(&cfg.SocketOpts), transport.WithTimeout(rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout), ) @@ -604,7 +604,7 @@ func (e *Etcd) servePeers() (err error) { } func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) { - if err = updateCipherSuites(&cfg.ClientTLSInfo, cfg.CipherSuites); err != nil { + if err = updateCipherSuites(cfg.ClientTLSInfo, cfg.CipherSuites); err != nil { return nil, err } if err = cfg.ClientSelfCert(); err != nil { @@ -726,7 +726,7 @@ func (e *Etcd) serveClients() (err error) { // start client servers in each goroutine for _, sctx := range e.sctxs { go func(s *serveCtx) { - e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, mux, e.errHandler, gopts...)) + e.errHandler(s.serve(e.Server, e.cfg.ClientTLSInfo, mux, e.errHandler, gopts...)) }(sctx) } return nil @@ -743,7 +743,7 @@ func (e *Etcd) serveMetrics() (err error) { etcdhttp.HandleHealth(e.cfg.logger, metricsMux, e.Server) for _, murl := range e.cfg.ListenMetricsUrls { - tlsInfo := &e.cfg.ClientTLSInfo + tlsInfo := e.cfg.ClientTLSInfo if murl.Scheme == "http" { tlsInfo = nil } diff --git a/server/embed/serve.go b/server/embed/serve.go index 7fff618a687c..2ff6310a5d9a 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -166,7 +166,7 @@ func (sctx *serveCtx) serve( } if sctx.secure { - tlscfg, tlsErr := tlsinfo.ServerConfig() + tlscfg, tlsErr := tlsinfo.ReloadableServerConfig() if tlsErr != nil { return tlsErr } diff --git a/server/etcdmain/config.go b/server/etcdmain/config.go index b14191a95cb8..0590ae051352 100644 --- a/server/etcdmain/config.go +++ b/server/etcdmain/config.go @@ -199,11 +199,13 @@ func newConfig() *config { fs.StringVar(&cfg.ec.ClientTLSInfo.CRLFile, "client-crl-file", "", "Path to the client certificate revocation list file.") fs.StringVar(&cfg.ec.ClientTLSInfo.AllowedHostname, "client-cert-allowed-hostname", "", "Allowed TLS hostname for client cert authentication.") fs.StringVar(&cfg.ec.ClientTLSInfo.TrustedCAFile, "trusted-ca-file", "", "Path to the client server TLS trusted CA cert file.") + fs.BoolVar(&cfg.ec.ClientTLSInfo.EnableRootCAReload, "client-root-ca-reload", false, "Enable client server TLS root CA dynamic reload to support root CA rotation") fs.BoolVar(&cfg.ec.ClientAutoTLS, "auto-tls", false, "Client TLS using generated certificates") fs.StringVar(&cfg.ec.PeerTLSInfo.CertFile, "peer-cert-file", "", "Path to the peer server TLS cert file.") fs.StringVar(&cfg.ec.PeerTLSInfo.KeyFile, "peer-key-file", "", "Path to the peer server TLS key file.") fs.StringVar(&cfg.ec.PeerTLSInfo.ClientCertFile, "peer-client-cert-file", "", "Path to an explicit peer client TLS cert file otherwise peer cert file will be used when client auth is required.") fs.StringVar(&cfg.ec.PeerTLSInfo.ClientKeyFile, "peer-client-key-file", "", "Path to an explicit peer client TLS key file otherwise peer key file will be used when client auth is required.") + fs.BoolVar(&cfg.ec.PeerTLSInfo.EnableRootCAReload, "peer-client-root-ca-reload", false, "Enable peer client TLS root CA dynamic reload to support root CA rotation") fs.BoolVar(&cfg.ec.PeerTLSInfo.ClientCertAuth, "peer-client-cert-auth", false, "Enable peer client cert authentication.") fs.StringVar(&cfg.ec.PeerTLSInfo.TrustedCAFile, "peer-trusted-ca-file", "", "Path to the peer server TLS trusted CA file.") fs.BoolVar(&cfg.ec.PeerAutoTLS, "peer-auto-tls", false, "Peer TLS using generated certificates") diff --git a/server/etcdmain/grpc_proxy.go b/server/etcdmain/grpc_proxy.go index aab2d6f83e8f..dd09fa36ad15 100644 --- a/server/etcdmain/grpc_proxy.go +++ b/server/etcdmain/grpc_proxy.go @@ -199,7 +199,7 @@ func startGRPCProxy(cmd *cobra.Command, args []string) { if err != nil { log.Fatal(err) } - tlsinfo = &autoTLS + tlsinfo = autoTLS } if tlsinfo != nil { lg.Info("gRPC proxy server TLS", zap.String("tls-info", fmt.Sprintf("%+v", tlsinfo))) @@ -514,7 +514,7 @@ func mustHTTPListener(lg *zap.Logger, m cmux.CMux, tlsinfo *transport.TLSInfo, c return srvhttp, m.Match(cmux.HTTP1()) } - srvTLS, err := tlsinfo.ServerConfig() + srvTLS, err := tlsinfo.ReloadableServerConfig() if err != nil { lg.Fatal("failed to set up TLS", zap.Error(err)) } diff --git a/server/etcdmain/util.go b/server/etcdmain/util.go index 0bd23e9e591b..e6a73bc3fa3b 100644 --- a/server/etcdmain/util.go +++ b/server/etcdmain/util.go @@ -60,7 +60,7 @@ func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool, ser ) } - endpoints, err = transport.ValidateSecureEndpoints(tlsInfo, endpoints) + endpoints, err = transport.ValidateSecureEndpoints(&tlsInfo, endpoints) if err != nil { if lg != nil { lg.Warn( diff --git a/server/etcdserver/api/rafthttp/functional_test.go b/server/etcdserver/api/rafthttp/functional_test.go index c6314b3dcd6f..6526f60b07a9 100644 --- a/server/etcdserver/api/rafthttp/functional_test.go +++ b/server/etcdserver/api/rafthttp/functional_test.go @@ -16,6 +16,7 @@ package rafthttp import ( "context" + "go.etcd.io/etcd/client/pkg/v3/transport" "net/http/httptest" "reflect" "testing" @@ -36,6 +37,7 @@ func TestSendMessage(t *testing.T) { Raft: &fakeRaft{}, ServerStats: newServerStats(), LeaderStats: stats.NewLeaderStats(zaptest.NewLogger(t), "1"), + TLSInfo: &transport.TLSInfo{}, } tr.Start() srv := httptest.NewServer(tr.Handler()) @@ -50,6 +52,7 @@ func TestSendMessage(t *testing.T) { Raft: p, ServerStats: newServerStats(), LeaderStats: stats.NewLeaderStats(zaptest.NewLogger(t), "2"), + TLSInfo: &transport.TLSInfo{}, } tr2.Start() srv2 := httptest.NewServer(tr2.Handler()) @@ -94,6 +97,7 @@ func TestSendMessageWhenStreamIsBroken(t *testing.T) { Raft: &fakeRaft{}, ServerStats: newServerStats(), LeaderStats: stats.NewLeaderStats(zaptest.NewLogger(t), "1"), + TLSInfo: &transport.TLSInfo{}, } tr.Start() srv := httptest.NewServer(tr.Handler()) @@ -108,6 +112,7 @@ func TestSendMessageWhenStreamIsBroken(t *testing.T) { Raft: p, ServerStats: newServerStats(), LeaderStats: stats.NewLeaderStats(zaptest.NewLogger(t), "2"), + TLSInfo: &transport.TLSInfo{}, } tr2.Start() srv2 := httptest.NewServer(tr2.Handler()) diff --git a/server/etcdserver/api/rafthttp/transport.go b/server/etcdserver/api/rafthttp/transport.go index fa3011cb39ac..627513c3e9cf 100644 --- a/server/etcdserver/api/rafthttp/transport.go +++ b/server/etcdserver/api/rafthttp/transport.go @@ -102,7 +102,7 @@ type Transport struct { // a distinct rate limiter is created per every peer (default value: 10 events/sec) DialRetryFrequency rate.Limit - TLSInfo transport.TLSInfo // TLS information used when creating connection + TLSInfo *transport.TLSInfo // TLS information used when creating connection ID types.ID // local member ID URLs types.URLs // local peer URLs diff --git a/server/etcdserver/api/rafthttp/util.go b/server/etcdserver/api/rafthttp/util.go index 91bc6884e4bc..2da4ec9e6a57 100644 --- a/server/etcdserver/api/rafthttp/util.go +++ b/server/etcdserver/api/rafthttp/util.go @@ -44,7 +44,7 @@ func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) { // NewRoundTripper returns a roundTripper used to send requests // to rafthttp listener of remote peers. -func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { +func NewRoundTripper(tlsInfo *transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { // It uses timeout transport to pair with remote timeout listeners. // It sets no read/write timeout, because message in requests may // take long time to write out before reading out the response. @@ -56,7 +56,7 @@ func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http // Read/write timeout is set for stream roundTripper to promptly // find out broken status, which minimizes the number of messages // sent on broken connection. -func newStreamRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { +func newStreamRoundTripper(tlsInfo *transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { return transport.NewTimeoutTransport(tlsInfo, dialTimeout, ConnReadTimeout, ConnWriteTimeout) } diff --git a/server/etcdserver/api/v2discovery/discovery.go b/server/etcdserver/api/v2discovery/discovery.go index 9f1bc0adf50c..7ab51e0530b5 100644 --- a/server/etcdserver/api/v2discovery/discovery.go +++ b/server/etcdserver/api/v2discovery/discovery.go @@ -133,7 +133,7 @@ func newDiscovery(lg *zap.Logger, durl, dproxyurl string, id types.ID) (*discove } // TODO: add ResponseHeaderTimeout back when watch on discovery service writes header early - tr, err := transport.NewTransport(transport.TLSInfo{}, 30*time.Second) + tr, err := transport.NewTransport(&transport.TLSInfo{}, 30*time.Second) if err != nil { return nil, err } diff --git a/tests/e2e/discovery_test.go b/tests/e2e/discovery_test.go index 8ace067d5c2f..0774381c84cd 100644 --- a/tests/e2e/discovery_test.go +++ b/tests/e2e/discovery_test.go @@ -83,9 +83,9 @@ func testClusterUsingDiscovery(t *testing.T, size int, peerTLS bool) { } func MustNewHTTPClient(t testutil.TB, eps []string, tls *transport.TLSInfo) client.Client { - cfgtls := transport.TLSInfo{} + cfgtls := &transport.TLSInfo{} if tls != nil { - cfgtls = *tls + cfgtls = tls } cfg := client.Config{Transport: mustNewTransport(t, cfgtls), Endpoints: eps} c, err := client.New(cfg) @@ -95,7 +95,7 @@ func MustNewHTTPClient(t testutil.TB, eps []string, tls *transport.TLSInfo) clie return c } -func mustNewTransport(t testutil.TB, tlsInfo transport.TLSInfo) *http.Transport { +func mustNewTransport(t testutil.TB, tlsInfo *transport.TLSInfo) *http.Transport { // tick in integration test is short, so 1s dial timeout could play well. tr, err := transport.NewTimeoutTransport(tlsInfo, time.Second, rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout) if err != nil { diff --git a/tests/framework/integration.go b/tests/framework/integration.go index 631b7263e696..e5de5e1e4546 100644 --- a/tests/framework/integration.go +++ b/tests/framework/integration.go @@ -71,7 +71,7 @@ func tlsInfo(t testing.TB, cfg config.TLSConfig) (*transport.TLSInfo, error) { if err != nil { return nil, fmt.Errorf("failed to generate cert: %s", err) } - return &tls, nil + return tls, nil case config.ManualTLS: return &integration.TestTLSInfo, nil default: diff --git a/tests/framework/integration/cluster.go b/tests/framework/integration/cluster.go index d9a21645348d..b17484895c95 100644 --- a/tests/framework/integration/cluster.go +++ b/tests/framework/integration/cluster.go @@ -651,7 +651,9 @@ func MustNewMember(t testutil.TB, mcfg MemberConfig) *Member { m.NewCluster = true m.BootstrapTimeout = 10 * time.Millisecond if m.PeerTLSInfo != nil { - m.ServerConfig.PeerTLSInfo = *m.PeerTLSInfo + m.ServerConfig.PeerTLSInfo = m.PeerTLSInfo.Clone() + } else { + m.ServerConfig.PeerTLSInfo = &transport.TLSInfo{} } m.ElectionTicks = ElectionTicks m.InitialElectionTickAdvance = true @@ -905,7 +907,7 @@ func (m *Member) Clone(t testutil.TB) *Member { return mm } -// Launch starts a member based on ServerConfig, PeerListeners +// Launch starts a member based on ReloadableServerConfig, PeerListeners // and ClientListeners. func (m *Member) Launch() error { m.Logger.Info( @@ -924,7 +926,7 @@ func (m *Member) Launch() error { var peerTLScfg *tls.Config if m.PeerTLSInfo != nil && !m.PeerTLSInfo.Empty() { - if peerTLScfg, err = m.PeerTLSInfo.ServerConfig(); err != nil { + if peerTLScfg, err = m.PeerTLSInfo.ReloadableServerConfig(); err != nil { return err } } @@ -934,7 +936,7 @@ func (m *Member) Launch() error { tlscfg *tls.Config ) if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() { - tlscfg, err = m.ClientTLSInfo.ServerConfig() + tlscfg, err = m.ClientTLSInfo.ReloadableServerConfig() if err != nil { return err } @@ -1008,7 +1010,7 @@ func (m *Member) Launch() error { hs.Start() } else { info := m.ClientTLSInfo - hs.TLS, err = info.ServerConfig() + hs.TLS, err = info.ReloadableServerConfig() if err != nil { return err } @@ -1271,7 +1273,7 @@ func (m *Member) Terminate(t testutil.TB) { // Metric gets the metric value for a member func (m *Member) Metric(metricName string, expectLabels ...string) (string, error) { cfgtls := transport.TLSInfo{} - tr, err := transport.NewTimeoutTransport(cfgtls, time.Second, time.Second, time.Second) + tr, err := transport.NewTimeoutTransport(&cfgtls, time.Second, time.Second, time.Second) if err != nil { return "", err } diff --git a/tests/functional/agent/handler.go b/tests/functional/agent/handler.go index 6d6023064c40..cb8a540093bf 100644 --- a/tests/functional/agent/handler.go +++ b/tests/functional/agent/handler.go @@ -17,6 +17,7 @@ package agent import ( "errors" "fmt" + "go.etcd.io/etcd/client/pkg/v3/transport" "net/url" "os" "os/exec" @@ -227,9 +228,10 @@ func (srv *Server) startProxy() error { srv.lg.Info("starting proxy on client traffic", zap.String("url", advertiseClientURL.String())) srv.advertiseClientPortToProxy[advertiseClientURLPort] = proxy.NewServer(proxy.ServerConfig{ - Logger: srv.lg, - From: *advertiseClientURL, - To: *listenClientURL, + Logger: srv.lg, + From: *advertiseClientURL, + To: *listenClientURL, + TLSInfo: &transport.TLSInfo{}, }) select { case err = <-srv.advertiseClientPortToProxy[advertiseClientURLPort].Error(): @@ -257,9 +259,10 @@ func (srv *Server) startProxy() error { srv.lg.Info("starting proxy on peer traffic", zap.String("url", advertisePeerURL.String())) srv.advertisePeerPortToProxy[advertisePeerURLPort] = proxy.NewServer(proxy.ServerConfig{ - Logger: srv.lg, - From: *advertisePeerURL, - To: *listenPeerURL, + Logger: srv.lg, + From: *advertisePeerURL, + To: *listenPeerURL, + TLSInfo: &transport.TLSInfo{}, }) select { case err = <-srv.advertisePeerPortToProxy[advertisePeerURLPort].Error(): diff --git a/tests/integration/clientv3/metrics_test.go b/tests/integration/clientv3/metrics_test.go index 26b24b29530c..9740259e0929 100644 --- a/tests/integration/clientv3/metrics_test.go +++ b/tests/integration/clientv3/metrics_test.go @@ -145,7 +145,7 @@ func sumCountersForMetricAndLabels(t *testing.T, url string, metricName string, } func getHTTPBodyAsLines(t *testing.T, url string) []string { - cfgtls := transport.TLSInfo{} + cfgtls := &transport.TLSInfo{} tr, err := transport.NewTransport(cfgtls, time.Second) if err != nil { t.Fatalf("Error getting transport: %v", err) diff --git a/tests/integration/embed/embed_test.go b/tests/integration/embed/embed_test.go index 3733684d2a20..507d534cce78 100644 --- a/tests/integration/embed/embed_test.go +++ b/tests/integration/embed/embed_test.go @@ -136,8 +136,8 @@ func testEmbedEtcdGracefulStop(t *testing.T, secure bool) { cfg := embed.NewConfig() if secure { - cfg.ClientTLSInfo = testTLSInfo - cfg.PeerTLSInfo = testTLSInfo + cfg.ClientTLSInfo = testTLSInfo.Clone() + cfg.PeerTLSInfo = testTLSInfo.Clone() } urls := newEmbedURLs(secure, 2) diff --git a/tests/integration/lazy_cluster.go b/tests/integration/lazy_cluster.go index 1d16d2d38184..a5204c26b2d0 100644 --- a/tests/integration/lazy_cluster.go +++ b/tests/integration/lazy_cluster.go @@ -79,7 +79,7 @@ func (lc *lazyCluster) mustLazyInit() { lc.once.Do(func() { lc.tb.Logf("LazyIniting ...") var err error - lc.transport, err = transport.NewTransport(transport.TLSInfo{}, time.Second) + lc.transport, err = transport.NewTransport(&transport.TLSInfo{}, time.Second) if err != nil { log.Fatal(err) } diff --git a/tests/integration/metrics_test.go b/tests/integration/metrics_test.go index 59ce0d1d3772..9741003f94f5 100644 --- a/tests/integration/metrics_test.go +++ b/tests/integration/metrics_test.go @@ -186,7 +186,7 @@ func TestMetricsHealth(t *testing.T) { clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1}) defer clus.Terminate(t) - tr, err := transport.NewTransport(transport.TLSInfo{}, 5*time.Second) + tr, err := transport.NewTransport(&transport.TLSInfo{}, 5*time.Second) if err != nil { t.Fatal(err) } diff --git a/tests/integration/root_ca_rotation_test.go b/tests/integration/root_ca_rotation_test.go new file mode 100644 index 000000000000..0c9f110fad13 --- /dev/null +++ b/tests/integration/root_ca_rotation_test.go @@ -0,0 +1,233 @@ +package integration + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "os" + "path" + "testing" + "time" + + "go.etcd.io/etcd/client/pkg/v3/transport" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/tests/v3/framework/integration" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +func newSerialNumber(t *testing.T) *big.Int { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fail() + } + return serialNumber +} + +func createRootCertificateAuthority(rootCaPath string, oldPem []byte, t *testing.T) (*x509.Certificate, []byte, *ecdsa.PrivateKey) { + serialNumber := newSerialNumber(t) + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"etcd"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * (24 * time.Hour)), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment, + BasicConstraintsValid: true, + IsCA: true, + } + + caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + caBlocks := [][]byte{caBytes} + if len(oldPem) > 0 { + caBlocks = append(caBlocks, oldPem) + } + marshalCerts(caBlocks, rootCaPath, t) + return ca, caBytes, priv +} + +func generateCerts(privKey *ecdsa.PrivateKey, rootCA *x509.Certificate, dir, suffix string, t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + serialNumber := newSerialNumber(t) + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"etcd"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * (24 * time.Hour)), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, rootCA, &priv.PublicKey, privKey) + if err != nil { + t.Fatal(err) + } + marshalCerts([][]byte{caBytes}, path.Join(dir, fmt.Sprintf("cert%s.pem", suffix)), t) + marshalKeys(priv, path.Join(dir, fmt.Sprintf("key%s.pem", suffix)), t) +} + +func marshalCerts(caBytes [][]byte, certPath string, t *testing.T) { + var caPerm bytes.Buffer + for _, caBlock := range caBytes { + err := pem.Encode(&caPerm, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBlock, + }) + if err != nil { + t.Fatal(err) + } + } + ioutil.WriteFile(certPath, caPerm.Bytes(), 0644) +} + +func marshalKeys(privKey *ecdsa.PrivateKey, keyPath string, t *testing.T) { + privBytes, err := x509.MarshalECPrivateKey(privKey) + if err != nil { + t.Fatal(err) + } + + var keyPerm bytes.Buffer + err = pem.Encode(&keyPerm, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: privBytes, + }) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(keyPath, keyPerm.Bytes(), 0644) +} + +func TestRootCARotation(t *testing.T) { + integration.BeforeTest(t) + + tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir-integration-reload") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + rootCAPath := path.Join(tmpdir, "ca-cert.pem") + logger := zap.NewExample() + rootCA, caBytes, privKey := createRootCertificateAuthority(rootCAPath, []byte{}, t) + generateCerts(privKey, rootCA, tmpdir, "_itest_old", t) + + tlsInfo := &transport.TLSInfo{ + TrustedCAFile: rootCAPath, + CertFile: path.Join(tmpdir, "cert_itest_old.pem"), + KeyFile: path.Join(tmpdir, "key_itest_old.pem"), + ClientCertFile: path.Join(tmpdir, "cert_itest_old.pem"), + ClientKeyFile: path.Join(tmpdir, "key_itest_old.pem"), + Logger: logger, + RefreshDuration: 100 * time.Millisecond, + EnableRootCAReload: true, + } + defer tlsInfo.Close() + + cluster := integration.NewCluster( + t, + &integration.ClusterConfig{ + Size: 1, + ClientTLS: tlsInfo, + }, + ) + defer cluster.Terminate(t) + + cc, err := tlsInfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + + cli, cerr := integration.NewClient(t, clientv3.Config{ + Endpoints: []string{cluster.Members[0].GRPCURL()}, + DialTimeout: time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + TLS: cc, + }) + + if cli != nil { + cli.Close() + } + + if cerr != nil { + t.Fatalf("expected TLS handshake success, got %v", cerr) + } + + // regenerate rootCA and sign new certs + rootCA, _, privKey = createRootCertificateAuthority(rootCAPath, caBytes, t) + generateCerts(privKey, rootCA, tmpdir, "_itest_new", t) + + // give server some time to reload new CA + time.Sleep(time.Second) + + newClientTlsinfo := &transport.TLSInfo{ + TrustedCAFile: rootCAPath, + CertFile: path.Join(tmpdir, "cert_itest_new.pem"), + KeyFile: path.Join(tmpdir, "key_itest_new.pem"), + ClientCertFile: path.Join(tmpdir, "cert_itest_new.pem"), + ClientKeyFile: path.Join(tmpdir, "key_itest_new.pem"), + Logger: logger, + } + + // old rootCA certs + cli, cerr = integration.NewClient(t, clientv3.Config{ + Endpoints: []string{cluster.Members[0].GRPCURL()}, + DialTimeout: time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + TLS: cc, + }) + + if cli != nil { + cli.Close() + } + + if cerr != nil { + t.Fatalf("expected TLS handshake success, got %v", cerr) + } + + // new rootCA certs + cc, err = newClientTlsinfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + + cli, cerr = integration.NewClient(t, clientv3.Config{ + Endpoints: []string{cluster.Members[0].GRPCURL()}, + DialTimeout: time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + TLS: cc, + }) + + if cli != nil { + cli.Close() + } + + if cerr != nil { + t.Fatalf("expected TLS handshake success, got %v", cerr) + } +} diff --git a/tests/integration/util_test.go b/tests/integration/util_test.go index 35b0d711ff43..bb796c55ab04 100644 --- a/tests/integration/util_test.go +++ b/tests/integration/util_test.go @@ -23,21 +23,21 @@ import ( ) // copyTLSFiles clones certs files to dst directory. -func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) { - ci := transport.TLSInfo{ +func copyTLSFiles(ti *transport.TLSInfo, dst string) (*transport.TLSInfo, error) { + ci := &transport.TLSInfo{ KeyFile: filepath.Join(dst, "server-key.pem"), CertFile: filepath.Join(dst, "server.pem"), TrustedCAFile: filepath.Join(dst, "etcd-root-ca.pem"), ClientCertAuth: ti.ClientCertAuth, } if err := copyFile(ti.KeyFile, ci.KeyFile); err != nil { - return transport.TLSInfo{}, err + return nil, err } if err := copyFile(ti.CertFile, ci.CertFile); err != nil { - return transport.TLSInfo{}, err + return nil, err } if err := copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil { - return transport.TLSInfo{}, err + return nil, err } return ci, nil } diff --git a/tests/integration/v3_grpc_test.go b/tests/integration/v3_grpc_test.go index a19500bd884b..55bef2613cba 100644 --- a/tests/integration/v3_grpc_test.go +++ b/tests/integration/v3_grpc_test.go @@ -1634,12 +1634,12 @@ func TestTLSReloadAtomicReplace(t *testing.T) { certsDirExp := t.TempDir() - cloneFunc := func() transport.TLSInfo { - tlsInfo, terr := copyTLSFiles(integration.TestTLSInfo, certsDir) + cloneFunc := func() *transport.TLSInfo { + tlsInfo, terr := copyTLSFiles(integration.TestTLSInfo.Clone(), certsDir) if terr != nil { t.Fatal(terr) } - if _, err := copyTLSFiles(integration.TestTLSInfoExpired, certsDirExp); err != nil { + if _, err := copyTLSFiles(integration.TestTLSInfoExpired.Clone(), certsDirExp); err != nil { t.Fatal(err) } return tlsInfo @@ -1676,20 +1676,20 @@ func TestTLSReloadAtomicReplace(t *testing.T) { func TestTLSReloadCopy(t *testing.T) { certsDir := t.TempDir() - cloneFunc := func() transport.TLSInfo { - tlsInfo, terr := copyTLSFiles(integration.TestTLSInfo, certsDir) + cloneFunc := func() *transport.TLSInfo { + tlsInfo, terr := copyTLSFiles(integration.TestTLSInfo.Clone(), certsDir) if terr != nil { t.Fatal(terr) } return tlsInfo } replaceFunc := func() { - if _, err := copyTLSFiles(integration.TestTLSInfoExpired, certsDir); err != nil { + if _, err := copyTLSFiles(integration.TestTLSInfoExpired.Clone(), certsDir); err != nil { t.Fatal(err) } } revertFunc := func() { - if _, err := copyTLSFiles(integration.TestTLSInfo, certsDir); err != nil { + if _, err := copyTLSFiles(integration.TestTLSInfo.Clone(), certsDir); err != nil { t.Fatal(err) } } @@ -1702,20 +1702,20 @@ func TestTLSReloadCopy(t *testing.T) { func TestTLSReloadCopyIPOnly(t *testing.T) { certsDir := t.TempDir() - cloneFunc := func() transport.TLSInfo { - tlsInfo, terr := copyTLSFiles(integration.TestTLSInfoIP, certsDir) + cloneFunc := func() *transport.TLSInfo { + tlsInfo, terr := copyTLSFiles(integration.TestTLSInfoIP.Clone(), certsDir) if terr != nil { t.Fatal(terr) } return tlsInfo } replaceFunc := func() { - if _, err := copyTLSFiles(integration.TestTLSInfoExpiredIP, certsDir); err != nil { + if _, err := copyTLSFiles(integration.TestTLSInfoExpiredIP.Clone(), certsDir); err != nil { t.Fatal(err) } } revertFunc := func() { - if _, err := copyTLSFiles(integration.TestTLSInfoIP, certsDir); err != nil { + if _, err := copyTLSFiles(integration.TestTLSInfoIP.Clone(), certsDir); err != nil { t.Fatal(err) } } @@ -1724,7 +1724,7 @@ func TestTLSReloadCopyIPOnly(t *testing.T) { func testTLSReload( t *testing.T, - cloneFunc func() transport.TLSInfo, + cloneFunc func() *transport.TLSInfo, replaceFunc func(), revertFunc func(), useIP bool) { @@ -1736,8 +1736,8 @@ func testTLSReload( // 2. start cluster with valid certs clus := integration.NewCluster(t, &integration.ClusterConfig{ Size: 1, - PeerTLS: &tlsInfo, - ClientTLS: &tlsInfo, + PeerTLS: tlsInfo, + ClientTLS: tlsInfo, UseIP: useIP, }) defer clus.Terminate(t) diff --git a/tests/integration/v3_tls_test.go b/tests/integration/v3_tls_test.go index 793d3d5a046b..5966f7262dc7 100644 --- a/tests/integration/v3_tls_test.go +++ b/tests/integration/v3_tls_test.go @@ -20,7 +20,7 @@ import ( "testing" "time" - "go.etcd.io/etcd/client/v3" + clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/tests/v3/framework/integration" "google.golang.org/grpc" ) @@ -41,14 +41,14 @@ func testTLSCipherSuites(t *testing.T, valid bool) { tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } - srvTLS, cliTLS := integration.TestTLSInfo, integration.TestTLSInfo + srvTLS, cliTLS := integration.TestTLSInfo.Clone(), integration.TestTLSInfo.Clone() if valid { srvTLS.CipherSuites, cliTLS.CipherSuites = cipherSuites, cipherSuites } else { srvTLS.CipherSuites, cliTLS.CipherSuites = cipherSuites[:2], cipherSuites[2:] } - clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1, ClientTLS: &srvTLS}) + clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1, ClientTLS: srvTLS}) defer clus.Terminate(t) cc, err := cliTLS.ClientConfig() diff --git a/tools/etcd-dump-metrics/metrics.go b/tools/etcd-dump-metrics/metrics.go index 643dc5fe1543..b867af82ec2e 100644 --- a/tools/etcd-dump-metrics/metrics.go +++ b/tools/etcd-dump-metrics/metrics.go @@ -28,7 +28,7 @@ import ( ) func fetchMetrics(ep string) (lines []string, err error) { - tr, err := transport.NewTimeoutTransport(transport.TLSInfo{}, time.Second, time.Second, time.Second) + tr, err := transport.NewTimeoutTransport(&transport.TLSInfo{}, time.Second, time.Second, time.Second) if err != nil { return nil, err }