Skip to content

Commit

Permalink
client: support Root CA rotation on server side
Browse files Browse the repository at this point in the history
  • Loading branch information
yishuT committed Aug 20, 2022
1 parent 1851316 commit 1754e54
Show file tree
Hide file tree
Showing 38 changed files with 817 additions and 234 deletions.
2 changes: 1 addition & 1 deletion client/pkg/transport/keepalive_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestNewKeepAliveListener(t *testing.T) {
}
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
tlscfg, err := tlsInfo.ServerConfig()
tlscfg, err := tlsInfo.ReloadableServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
Expand Down
251 changes: 178 additions & 73 deletions client/pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"

"go.etcd.io/etcd/client/pkg/v3/fileutil"
Expand All @@ -38,6 +40,10 @@ import (
"go.uber.org/zap"
)

const (
defaultRootCAReloadDuration = 5 * time.Minute
)

// NewListener creates a new listner.
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
Expand Down Expand Up @@ -185,17 +191,51 @@ type TLSInfo struct {
// EmptyCN indicates that the cert must have empty CN.
// If true, ClientConfig() will return an error for a cert with non empty CN.
EmptyCN bool

tlsConfig atomic.Value // *tls.Config
refreshOnce sync.Once
RefreshDuration time.Duration
EnableRootCAReload bool
refreshDone chan struct{}
}

func (info TLSInfo) String() string {
func (info *TLSInfo) String() string {
return fmt.Sprintf("cert = %s, key = %s, client-cert=%s, client-key=%s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.ClientCertFile, info.ClientKeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
}

func (info TLSInfo) Empty() bool {
func (info *TLSInfo) Empty() bool {
return info.CertFile == "" && info.KeyFile == ""
}

func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
func (info *TLSInfo) Clone() *TLSInfo {
return &TLSInfo{
CertFile: info.CertFile,
KeyFile: info.KeyFile,
ClientCertFile: info.ClientCertFile,
ClientKeyFile: info.ClientKeyFile,
TrustedCAFile: info.TrustedCAFile,
ClientCertAuth: info.ClientCertAuth,
CRLFile: info.CRLFile,
InsecureSkipVerify: info.InsecureSkipVerify,
SkipClientSANVerify: info.SkipClientSANVerify,
ServerName: info.ServerName,
HandshakeFailure: info.HandshakeFailure,
CipherSuites: info.CipherSuites,
selfCert: info.selfCert,
parseFunc: info.parseFunc,
AllowedCN: info.AllowedCN,
AllowedHostname: info.AllowedHostname,
Logger: info.Logger,
EmptyCN: info.EmptyCN,
RefreshDuration: info.RefreshDuration,
EnableRootCAReload: info.EnableRootCAReload,
}
}

func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info *TLSInfo, err error) {
if info == nil {
info = &TLSInfo{}
}
info.Logger = lg
if selfSignedCertValidity == 0 {
err = fmt.Errorf("selfSignedCertValidity is invalid,it should be greater than 0")
Expand Down Expand Up @@ -334,6 +374,87 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali
return SelfCert(lg, dirpath, hosts, selfSignedCertValidity)
}

func (info *TLSInfo) startRefresh() {
info.refreshOnce.Do(
func() {
info.loadServerTlsConfig()
if info.EnableRootCAReload {
if info.RefreshDuration == 0 {
info.RefreshDuration = defaultRootCAReloadDuration
}
info.refreshDone = make(chan struct{})
go info.tlsConfigRefreshLoop()
}
},
)
}

func (info *TLSInfo) loadServerTlsConfig() {
if info.Logger != nil {
info.Logger.Info("tls config reload from files")
}
cfg, err := info.serverConfig()
if err == nil {
info.tlsConfig.Store(cfg)
} else {
if info.Logger != nil {
info.Logger.Error("reload tls config error:", zap.Error(err))
}
}
}

func (info *TLSInfo) tlsConfigRefreshLoop() {
ticker := time.NewTicker(info.RefreshDuration)
defer ticker.Stop()
for {
select {
case <-ticker.C:
info.loadServerTlsConfig()
case <-info.refreshDone:
return
}
}
}

func (info *TLSInfo) getClientCertificate() (*tls.Certificate, error) {
certFile, keyFile := info.CertFile, info.KeyFile
if info.ClientCertFile != "" {
certFile, keyFile = info.ClientCertFile, info.ClientKeyFile
}
return info.getCertificates(certFile, keyFile)
}

func (info *TLSInfo) getServerCertificates() (*tls.Certificate, error) {
return info.getCertificates(info.CertFile, info.KeyFile)
}

func (info *TLSInfo) getCertificates(certFile, keyFile string) (*tls.Certificate, error) {
cert, err := tlsutil.NewCert(certFile, keyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find cert files",
zap.String("cert-file", certFile),
zap.String("key-file", keyFile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create peer certificate",
zap.String("cert-file", certFile),
zap.String("key-file", keyFile),
zap.Error(err),
)
}
}
if err != nil {
return nil, err
}
return cert, err
}

// baseConfig is called on initial TLS handshake start.
//
// Previously,
Expand All @@ -354,7 +475,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali
// handshake, in order to trigger (*tls.Config).GetCertificate and populate
// rest of the certificates on every new TLS connection, even when client
// SNI is empty (e.g. cert only includes IPs).
func (info TLSInfo) baseConfig() (*tls.Config, error) {
func (info *TLSInfo) baseConfig() (*tls.Config, error) {
if info.KeyFile == "" || info.CertFile == "" {
return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
}
Expand Down Expand Up @@ -415,82 +536,15 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
return errors.New("client certificate authentication failed")
}
}

// this only reloads certs when there's a client request
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find peer cert files",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create peer certificate",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
}
return cert, err
}
cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
certfile, keyfile := info.CertFile, info.KeyFile
if info.ClientCertFile != "" {
certfile, keyfile = info.ClientCertFile, info.ClientKeyFile
}
cert, err = tlsutil.NewCert(certfile, keyfile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find client cert files",
zap.String("cert-file", certfile),
zap.String("key-file", keyfile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create client certificate",
zap.String("cert-file", certfile),
zap.String("key-file", keyfile),
zap.Error(err),
)
}
}
return cert, err
}
return cfg, nil
}

// cafiles returns a list of CA file paths.
func (info TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ServerConfig generates a tls.Config object for use by an HTTP server.
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
func (info *TLSInfo) serverConfig() (*tls.Config, error) {
cfg, err := info.baseConfig()
if err != nil {
return nil, err
}

if info.Logger == nil {
info.Logger = zap.NewNop()
}

cfg.ClientAuth = tls.NoClientCert
if info.TrustedCAFile != "" || info.ClientCertAuth {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
Expand All @@ -515,11 +569,44 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
// setting Max TLS version to TLS 1.2 for go 1.13
cfg.MaxVersion = tls.VersionTLS12

certs, err := info.getServerCertificates()
if err != nil {
return nil, err
}
cfg.Certificates = []tls.Certificate{*certs}
return cfg, nil
}

// cafiles returns a list of CA file paths.
func (info *TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ReloadableServerConfig generates a tls.Config object for use by an HTTP server.
func (info *TLSInfo) ReloadableServerConfig() (*tls.Config, error) {
info.startRefresh()
return &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
cfg, ok := info.tlsConfig.Load().(*tls.Config)
if !ok {
return nil, errors.New("server tls configuration not ready")
}
return cfg.Clone(), nil
},
// Needed to tell go http server to serve http2
NextProtos: []string{"h2"},
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return info.getClientCertificate()
},
}, nil
}

// ClientConfig generates a tls.Config object for use by an HTTP client.
func (info TLSInfo) ClientConfig() (*tls.Config, error) {
func (info *TLSInfo) ClientConfig() (*tls.Config, error) {
var cfg *tls.Config
var err error

Expand Down Expand Up @@ -574,9 +661,27 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
// setting Max TLS version to TLS 1.2 for go 1.13
cfg.MaxVersion = tls.VersionTLS12

cert, err := info.getClientCertificate()
if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"cannot create client certificate",
zap.Error(err),
)
}
} else {
cfg.Certificates = []tls.Certificate{*cert}
}

return cfg, nil
}

func (info *TLSInfo) Close() {
if info.refreshDone != nil {
close(info.refreshDone)
}
}

// IsClosedConnError returns true if the error is from closing listener, cmux.
// copied from golang.org/x/net/http2/http2.go
func IsClosedConnError(err error) bool {
Expand Down
Loading

0 comments on commit 1754e54

Please sign in to comment.