Skip to content

Commit

Permalink
Enable reloading CA without a restart
Browse files Browse the repository at this point in the history
* Add two options to server: "client-root-ca-reload" and "peer-root-ca-reload".
  By default, these options are set to false. Whenever the options are enabled,
  the server will dynamically load CA keys & certs.
* Provide implementation for "GetConfigForClient". This will allow server to
  load CA files on each TLS handshake.
* Provide implementation for "VerifyConnection". This will clients (for peer connection)
  to load CA files per request.

Note: this patch implements CA reloading without performance optimization.
Optimization could be done in the future. Potential optimization is
to avoid loading CA on each request. We could implement a background
routine to periodically loading CA files instead.

Signed-off-by: Hongbin Lu <hongbinlu@microsoft.com>
  • Loading branch information
hongbin committed Aug 30, 2023
1 parent 10498ce commit 588a667
Show file tree
Hide file tree
Showing 15 changed files with 388 additions and 12 deletions.
76 changes: 71 additions & 5 deletions client/pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ type TLSInfo struct {
// EmptyCN indicates that the cert must have empty CN.
// If true, ClientConfig() will return an error for a cert with non empty CN.
EmptyCN bool

// EnableRootCAReload indicates whether to reload root CA dynamically.
EnableRootCAReload bool
}

func (info TLSInfo) String() string {
Expand Down Expand Up @@ -435,10 +438,21 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
}
}

// this only reloads certs when there's a client request
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
if info.EnableRootCAReload {
cfg.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
cfg, err := info.ServerConfig()
if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create tls config",
zap.Error(err),
)
}
}
return cfg, err
}

cert, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
Expand All @@ -458,7 +472,33 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
)
}
}
return cert, err
cfg.Certificates = []tls.Certificate{*cert}
} else {
// this only reloads certs when there's a client request
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find peer cert files",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create peer certificate",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
}
return cert, err
}
}
cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
certfile, keyfile := info.CertFile, info.KeyFile
Expand Down Expand Up @@ -557,6 +597,32 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {

if info.selfCert {
cfg.InsecureSkipVerify = true
} else if info.EnableRootCAReload {
if len(cs) == 0 {
return nil, fmt.Errorf("cannot enable root CA reloading without a trusted CA file")
}

// Set InsecureSkipVerify to skip the default validation we are replacing.
// This will not disable VerifyConnection.
cfg.InsecureSkipVerify = true

cfg.VerifyConnection = func(connState tls.ConnectionState) error {
// dynamically load CA from file
rootCAs, err := tlsutil.NewCertPool(cs)
if err != nil {
return err
}
opts := x509.VerifyOptions{
DNSName: connState.ServerName,
Intermediates: x509.NewCertPool(),
Roots: rootCAs,
}
for _, cert := range connState.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert)
}
_, err = connState.PeerCertificates[0].Verify(opts)
return err
}
}

if info.EmptyCN {
Expand Down
2 changes: 2 additions & 0 deletions server/etcdmain/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ func newConfig() *config {
fs.StringVar(&cfg.ec.ClientTLSInfo.CRLFile, "client-crl-file", "", "Path to the client certificate revocation list file.")
fs.StringVar(&cfg.ec.ClientTLSInfo.AllowedHostname, "client-cert-allowed-hostname", "", "Allowed TLS hostname for client cert authentication.")
fs.StringVar(&cfg.ec.ClientTLSInfo.TrustedCAFile, "trusted-ca-file", "", "Path to the client server TLS trusted CA cert file.")
fs.BoolVar(&cfg.ec.ClientTLSInfo.EnableRootCAReload, "client-root-ca-reload", false, "Enable client server TLS root CA dynamic reload to support root CA rotation")
fs.BoolVar(&cfg.ec.ClientAutoTLS, "auto-tls", false, "Client TLS using generated certificates")
fs.StringVar(&cfg.ec.PeerTLSInfo.CertFile, "peer-cert-file", "", "Path to the peer server TLS cert file.")
fs.StringVar(&cfg.ec.PeerTLSInfo.KeyFile, "peer-key-file", "", "Path to the peer server TLS key file.")
fs.StringVar(&cfg.ec.PeerTLSInfo.ClientCertFile, "peer-client-cert-file", "", "Path to an explicit peer client TLS cert file otherwise peer cert file will be used when client auth is required.")
fs.StringVar(&cfg.ec.PeerTLSInfo.ClientKeyFile, "peer-client-key-file", "", "Path to an explicit peer client TLS key file otherwise peer key file will be used when client auth is required.")
fs.BoolVar(&cfg.ec.PeerTLSInfo.EnableRootCAReload, "peer-root-ca-reload", false, "Enable peer client TLS root CA dynamic reload to support root CA rotation")
fs.BoolVar(&cfg.ec.PeerTLSInfo.ClientCertAuth, "peer-client-cert-auth", false, "Enable peer client cert authentication.")
fs.StringVar(&cfg.ec.PeerTLSInfo.TrustedCAFile, "peer-trusted-ca-file", "", "Path to the peer server TLS trusted CA file.")
fs.BoolVar(&cfg.ec.PeerAutoTLS, "peer-auto-tls", false, "Peer TLS using generated certificates")
Expand Down
4 changes: 4 additions & 0 deletions server/etcdmain/help.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ Security:
Allowed TLS hostname for client cert authentication.
--trusted-ca-file ''
Path to the client server TLS trusted CA cert file.
--client-root-ca-reload 'false'
Enable client server TLS root CA dynamic reload to support root CA rotation.
--auto-tls 'false'
Client TLS using generated certificates.
--peer-cert-file ''
Expand All @@ -201,6 +203,8 @@ Security:
Path to an explicit peer client TLS key file otherwise peer key file will be used when client auth is required.
--peer-trusted-ca-file ''
Path to the peer server TLS trusted CA file.
--peer-root-ca-reload 'false'
Enable peer client TLS root CA dynamic reload to support root CA rotation.
--peer-cert-allowed-cn ''
Required CN for client certs connecting to the peer endpoint.
--peer-cert-allowed-hostname ''
Expand Down
5 changes: 5 additions & 0 deletions tests/common/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package common

import (
"go.etcd.io/etcd/client/pkg/v3/fileutil"
"go.etcd.io/etcd/client/pkg/v3/transport"
"go.etcd.io/etcd/tests/v3/framework"
"go.etcd.io/etcd/tests/v3/framework/config"
"go.etcd.io/etcd/tests/v3/framework/e2e"
Expand Down Expand Up @@ -81,3 +82,7 @@ func WithAuth(userName, password string) config.ClientOption {
func WithEndpoints(endpoints []string) config.ClientOption {
return e2e.WithEndpoints(endpoints)
}

func WithTLSInfo(tlsInfo *transport.TLSInfo) config.ClientOption {
return e2e.WithTLSInfo(tlsInfo)
}
5 changes: 5 additions & 0 deletions tests/common/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package common

import (
"go.etcd.io/etcd/client/pkg/v3/transport"
"go.etcd.io/etcd/tests/v3/framework"
"go.etcd.io/etcd/tests/v3/framework/config"
"go.etcd.io/etcd/tests/v3/framework/integration"
Expand Down Expand Up @@ -59,3 +60,7 @@ func WithAuth(userName, password string) config.ClientOption {
func WithEndpoints(endpoints []string) config.ClientOption {
return integration.WithEndpoints(endpoints)
}

func WithTLSInfo(tlsInfo *transport.TLSInfo) config.ClientOption {
return integration.WithTLSInfo(tlsInfo)
}
206 changes: 206 additions & 0 deletions tests/common/root_ca_rotation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package common

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"os"
"path"
"testing"
"time"

"go.etcd.io/etcd/client/pkg/v3/transport"
"go.etcd.io/etcd/tests/v3/framework/config"
"go.etcd.io/etcd/tests/v3/framework/testutils"
)

func newSerialNumber(t *testing.T) *big.Int {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
t.Fail()
}
return serialNumber
}

func createRootCertificateAuthority(rootCaPath string, oldPem []byte, t *testing.T) (*x509.Certificate, []byte, *ecdsa.PrivateKey) {
serialNumber := newSerialNumber(t)
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatal(err)
}

tmpl := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{Organization: []string{"etcd"}},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * (24 * time.Hour)),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment,
BasicConstraintsValid: true,
IsCA: true,
}

caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
if err != nil {
t.Fatal(err)
}

ca, err := x509.ParseCertificate(caBytes)
if err != nil {
t.Fatal(err)
}
caBlocks := [][]byte{caBytes}
if len(oldPem) > 0 {
caBlocks = append(caBlocks, oldPem)
}
marshalCerts(caBlocks, rootCaPath, t)
return ca, caBytes, priv
}

func generateCerts(privKey *ecdsa.PrivateKey, rootCA *x509.Certificate, dir, suffix string, t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatal(err)
}
serialNumber := newSerialNumber(t)
tmpl := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{Organization: []string{"etcd"}},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * (24 * time.Hour)),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment,
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
caBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, rootCA, &priv.PublicKey, privKey)
if err != nil {
t.Fatal(err)
}
marshalCerts([][]byte{caBytes}, path.Join(dir, fmt.Sprintf("cert%s.pem", suffix)), t)
marshalKeys(priv, path.Join(dir, fmt.Sprintf("key%s.pem", suffix)), t)
}

func marshalCerts(caBytes [][]byte, certPath string, t *testing.T) {
var caPerm bytes.Buffer
for _, caBlock := range caBytes {
err := pem.Encode(&caPerm, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBlock,
})
if err != nil {
t.Fatal(err)
}
}
ioutil.WriteFile(certPath, caPerm.Bytes(), 0644)
}

func marshalKeys(privKey *ecdsa.PrivateKey, keyPath string, t *testing.T) {
privBytes, err := x509.MarshalECPrivateKey(privKey)
if err != nil {
t.Fatal(err)
}

var keyPerm bytes.Buffer
err = pem.Encode(&keyPerm, &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: privBytes,
})
if err != nil {
t.Fatal(err)
}
ioutil.WriteFile(keyPath, keyPerm.Bytes(), 0644)
}

func TestRootCARotation(t *testing.T) {
testRunner.BeforeTest(t)

t.Run("server CA rotation", func(t *testing.T) {
tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir-integration-reload")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
rootCAPath := path.Join(tmpdir, "ca-cert.pem")
rootCA, caBytes, privKey := createRootCertificateAuthority(rootCAPath, []byte{}, t)
generateCerts(privKey, rootCA, tmpdir, "_itest_old", t)

tlsInfo := &transport.TLSInfo{
TrustedCAFile: rootCAPath,
CertFile: path.Join(tmpdir, "cert_itest_old.pem"),
KeyFile: path.Join(tmpdir, "key_itest_old.pem"),
ClientCertFile: path.Join(tmpdir, "cert_itest_old.pem"),
ClientKeyFile: path.Join(tmpdir, "key_itest_old.pem"),
EnableRootCAReload: true,
}
clusConfig := config.ClusterConfig{ClusterSize: 1, ClientTLS: config.ManualTLS, ClientTLSInfo: tlsInfo}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clus := testRunner.NewCluster(ctx, t, config.WithClusterConfig(clusConfig))
defer clus.Close()

cc, cerr := clus.Client(WithTLSInfo(tlsInfo))
if cerr != nil {
t.Fatalf("expected TLS handshake success, got %v", cerr)
}
testutils.ExecuteUntil(ctx, t, func() {
key := "foo"
_, err := cc.Get(ctx, key, config.GetOptions{})
if err != nil {
t.Fatalf("Unexpeted result, err: %s", err)
}
})

// regenerate rootCA and sign new certs
rootCA, _, privKey = createRootCertificateAuthority(rootCAPath, caBytes, t)
generateCerts(privKey, rootCA, tmpdir, "_itest_new", t)

// old rootCA certs
cc, cerr = clus.Client(WithTLSInfo(tlsInfo))
if cerr != nil {
t.Fatalf("expected TLS handshake success, got %v", cerr)
}
testutils.ExecuteUntil(ctx, t, func() {
key := "foo"
_, err := cc.Get(ctx, key, config.GetOptions{})
if err != nil {
t.Fatalf("Unexpeted result, err: %s", err)
}
})

// new rootCA certs
newClientTlsinfo := &transport.TLSInfo{
TrustedCAFile: rootCAPath,
CertFile: path.Join(tmpdir, "cert_itest_new.pem"),
KeyFile: path.Join(tmpdir, "key_itest_new.pem"),
ClientCertFile: path.Join(tmpdir, "cert_itest_new.pem"),
ClientKeyFile: path.Join(tmpdir, "key_itest_new.pem"),
}

cc, cerr = clus.Client(WithTLSInfo(newClientTlsinfo))
if cerr != nil {
t.Fatalf("expected TLS handshake success, got %v", cerr)
}
testutils.ExecuteUntil(ctx, t, func() {
key := "foo"
_, err := cc.Get(ctx, key, config.GetOptions{})
if err != nil {
t.Fatalf("Unexpeted result, err: %s", err)
}
})
})

// TODO(hongbin): added test for peer CA rotation
}
5 changes: 5 additions & 0 deletions tests/common/unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package common

import (
"go.etcd.io/etcd/client/pkg/v3/transport"
"go.etcd.io/etcd/tests/v3/framework"
"go.etcd.io/etcd/tests/v3/framework/config"
)
Expand All @@ -40,3 +41,7 @@ func WithAuth(userName, password string) config.ClientOption {
func WithEndpoints(endpoints []string) config.ClientOption {
return func(any) {}
}

func WithTLSInfo(tlsInfo *transport.TLSInfo) config.ClientOption {
return func(any) {}
}
2 changes: 2 additions & 0 deletions tests/e2e/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ func tlsInfo(t testing.TB, cfg e2e.ClientConfig) (*transport.TLSInfo, error) {
return nil, fmt.Errorf("failed to generate cert: %s", err)
}
return &tls, nil
} else if cfg.TLSInfo != nil {
return cfg.TLSInfo, nil
} else {
return &integration.TestTLSInfo, nil
}
Expand Down
Loading

0 comments on commit 588a667

Please sign in to comment.