diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 9dac85ade86..4c846a245e8 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -89,7 +89,7 @@ func (info TLSInfo) Empty() bool { return info.CertFile == "" && info.KeyFile == "" } -func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) { +func SelfCert(dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) { if err = os.MkdirAll(dirpath, 0700); err != nil { return } @@ -118,7 +118,7 @@ func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) { NotAfter: time.Now().Add(365 * (24 * time.Hour)), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + ExtKeyUsage: append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...), BasicConstraintsValid: true, } diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index 6cc44a118f9..421ba124230 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -22,14 +22,20 @@ import ( "os" "testing" "time" + "crypto/x509" + "net" ) -func createSelfCert() (*TLSInfo, func(), error) { +func createSelfCert(hosts ...string) (*TLSInfo, func(), error) { + return createSelfCertEx("127.0.0.1") +} + +func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) { d, terr := ioutil.TempDir("", "etcd-test-tls-") if terr != nil { return nil, nil, terr } - info, err := SelfCert(d, []string{"127.0.0.1"}) + info, err := SelfCert(d, []string{host + ":0"}, additionalUsages...) if err != nil { return nil, nil, err } @@ -74,6 +80,100 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { } } +// TestNewListenerTLSInfoClientVerify tests that if client IP address mismatches +// with specified address in its certificate the connection is rejected +func TestNewListenerTLSInfoClientVerify(t *testing.T) { + tests := []struct { + goodClientHost bool + acceptExpected bool + }{ + {true, true}, + {false, false}, + } + for _, test := range tests { + testNewListenerTLSInfoClientCheck(t, test.goodClientHost, test.acceptExpected) + } +} + +func testNewListenerTLSInfoClientCheck(t *testing.T, goodClientHost, acceptExpected bool) { + tlsInfo, del, err := createSelfCert() + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del() + + host := "127.0.0.222" + if goodClientHost { + host = "127.0.0.1" + } + clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth) + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del2() + + tlsInfo.CAFile = clientTLSInfo.CertFile + + rootCAs := x509.NewCertPool() + loaded, err := ioutil.ReadFile(tlsInfo.CertFile) + if err != nil { + t.Fatalf("unexpected missing certfile: %v", err) + } + rootCAs.AppendCertsFromPEM(loaded) + + clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile) + if err != nil { + t.Fatalf("unable to create peer cert: %v", err) + } + + tlsConfig := &tls.Config{} + tlsConfig.InsecureSkipVerify = false + tlsConfig.Certificates = []tls.Certificate{clientCert} + tlsConfig.RootCAs = rootCAs + + ln, err := NewListener("127.0.0.1:0", "https", tlsInfo) + if err != nil { + t.Fatalf("unexpected NewListener error: %v", err) + } + defer ln.Close() + + tr := &http.Transport{TLSClientConfig: tlsConfig} + cli := &http.Client{Transport: tr} + chClientErr := make(chan error) + go func() { + _, err := cli.Get("https://" + ln.Addr().String()) + chClientErr <- err + }() + + chAcceptErr := make(chan error) + chAcceptConn := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + if err != nil { + chAcceptErr <- err + } else { + chAcceptConn <- conn + } + }() + + select { + case <-chClientErr: + if acceptExpected { + t.Errorf("accepted for good client address: goodClientHost=%t", goodClientHost) + } + case acceptErr := <-chAcceptErr: + t.Fatalf("unexpected Accept error: %v", acceptErr) + case conn := <-chAcceptConn: + defer conn.Close() + if _, ok := conn.(*tls.Conn); !ok { + t.Errorf("failed to accept *tls.Conn") + } + if !acceptExpected { + t.Errorf("accepted for bad client address: goodClientHost=%t", goodClientHost) + } + } +} + func TestNewListenerTLSEmptyInfo(t *testing.T) { _, err := NewListener("127.0.0.1:0", "https", nil) if err == nil {