Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support ssl/tls #868

Merged
merged 17 commits into from
Dec 1, 2017
Merged
15 changes: 12 additions & 3 deletions cmd/pd-ctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ import (
)

var (
url string
detach bool
version bool
url string
detach bool
version bool
CAPath string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why public?

CertPath string
KeyPath string
)

func init() {
flag.StringVarP(&url, "pd", "u", "http://127.0.0.1:2379", "The pd address")
flag.BoolVarP(&detach, "detach", "d", false, "Run pdctl without readline")
flag.BoolVarP(&version, "version", "V", false, "print version information and exit")
flag.StringVar(&CAPath, "cacert", "", "path of file that contains list of trusted SSL CAs.")
flag.StringVar(&CertPath, "cert", "", "path of file that contains X509 certificate in PEM format.")
flag.StringVar(&KeyPath, "key", "", "path of file that contains X509 key in PEM format.")
}

func main() {
Expand Down Expand Up @@ -115,6 +121,9 @@ func loop() {
}
args := strings.Split(strings.TrimSpace(line), " ")
args = append(args, "-u", url)
args = append(args, "--cacert", CAPath)
args = append(args, "--cert", CertPath)
args = append(args, "--key", KeyPath)
pdctl.Start(args)
}
}
10 changes: 9 additions & 1 deletion cmd/pd-tso-bench/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,20 @@ var (
concurrency = flag.Int("C", 1000, "concurrency")
sleep = flag.Duration("sleep", time.Millisecond, "sleep time after a request, used to adjust pressure")
interval = flag.Duration("interval", time.Second, "interval to output the statistics")
CAPath = flag.String("cacert", "", "path of file that contains list of trusted SSL CAs.")
CertPath = flag.String("cert", "", "path of file that contains X509 certificate in PEM format..")
KeyPath = flag.String("key", "", "path of file that contains X509 key in PEM format.")
wg sync.WaitGroup
)

func main() {
flag.Parse()
pdCli, err := pd.NewClient([]string{*pdAddrs})

pdCli, err := pd.NewClient([]string{*pdAddrs}, pd.SecurityOption{
CAPath: *CAPath,
CertPath: *CertPath,
KeyPath: *KeyPath,
})
if err != nil {
log.Fatal(err)
}
Expand Down
9 changes: 9 additions & 0 deletions conf/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ initial-cluster-state = "new"
lease = 3
tso-save-interval = "3s"

[security]
# Path of file that contains list of trusted SSL CAs. if set, following four settings shouldn't be empty
cacert-path = ""
# Path of file that contains X509 certificate in PEM format.
cert-path = ""
# Path of file that contains X509 key in PEM format.
key-path = ""

[log]
level = "info"

Expand Down Expand Up @@ -68,3 +76,4 @@ max-replicas = 3
# For example, ["zone", "rack"] means that we should place replicas to
# different zones first, then to different racks if we don't have enough zones.
location-labels = []

56 changes: 53 additions & 3 deletions pd-client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
package pd

import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"
"strings"
"sync"
"time"
Expand All @@ -25,6 +29,7 @@ import (
"github.com/pingcap/kvproto/pkg/pdpb"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// Client is a PD (Placement Driver) client.
Expand Down Expand Up @@ -93,10 +98,19 @@ type client struct {
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc

security SecurityOption
}

// SecurityOption records options about tls
type SecurityOption struct {
CAPath string
CertPath string
KeyPath string
}

// NewClient creates a PD client.
func NewClient(pdAddrs []string) (Client, error) {
func NewClient(pdAddrs []string, security SecurityOption) (Client, error) {
log.Infof("[pd] create pd client with endpoints %v", pdAddrs)
ctx, cancel := context.WithCancel(context.Background())
c := &client{
Expand All @@ -106,6 +120,7 @@ func NewClient(pdAddrs []string) (Client, error) {
checkLeaderCh: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
security: security,
}
c.connMu.clientConns = make(map[string]*grpc.ClientConn)

Expand Down Expand Up @@ -213,11 +228,46 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) {
return conn, nil
}

cc, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) // TODO: Support HTTPS.
opt := grpc.WithInsecure()
if len(c.security.CAPath) != 0 {

certificates := []tls.Certificate{}
if len(c.security.CertPath) != 0 && len(c.security.KeyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(c.security.CertPath, c.security.KeyPath)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(c.security.CAPath)
if err != nil {
return nil, errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

creds := credentials.NewTLS(&tls.Config{
Certificates: certificates,
RootCAs: certPool,
})

opt = grpc.WithTransportCredentials(creds)
}
u, err := url.Parse(addr)
if err != nil {
return nil, errors.Trace(err)
}
cc, err := grpc.Dial(u.Host, opt)
if err != nil {
return nil, errors.Trace(err)
}

c.connMu.Lock()
defer c.connMu.Unlock()
if old, ok := c.connMu.clientConns[addr]; ok {
Expand Down
2 changes: 1 addition & 1 deletion pd-client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *testClientSuite) SetUpSuite(c *C) {
bootstrapServer(c, newHeader(s.srv), s.grpcPDClient)

var err error
s.client, err = NewClient(s.srv.GetEndpoints())
s.client, err = NewClient(s.srv.GetEndpoints(), SecurityOption{})
c.Assert(err, IsNil)
s.regionHeartbeat, err = s.grpcPDClient.RegionHeartbeat(context.Background())
c.Assert(err, IsNil)
Expand Down
8 changes: 4 additions & 4 deletions pd-client/leader_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *testLeaderChangeSuite) TestLeaderConfigChange(c *C) {
svrs, endpoints, closeFunc := s.prepareClusterN(c, 3)
defer closeFunc()

cli, err := NewClient(endpoints)
cli, err := NewClient(endpoints, SecurityOption{})
c.Assert(err, IsNil)
defer cli.Close()

Expand Down Expand Up @@ -109,7 +109,7 @@ func (s *testLeaderChangeSuite) TestMemberList(c *C) {
_, endpoints, closeFunc := s.prepareClusterN(c, 2)
defer closeFunc()

cli, err := NewClient(endpoints[:1])
cli, err := NewClient(endpoints[:1], SecurityOption{})
c.Assert(err, IsNil)
cli.Close()

Expand All @@ -122,7 +122,7 @@ func (s *testLeaderChangeSuite) TestLeaderChange(c *C) {
svrs, endpoints, closeFunc := s.prepareClusterN(c, 3)
defer closeFunc()

cli, err := NewClient(endpoints)
cli, err := NewClient(endpoints, SecurityOption{})
c.Assert(err, IsNil)
defer cli.Close()

Expand Down Expand Up @@ -163,7 +163,7 @@ func (s *testLeaderChangeSuite) TestLeaderTransfer(c *C) {
servers, endpoints, closeFunc := s.prepareClusterN(c, 2)
defer closeFunc()

cli, err := NewClient(endpoints)
cli, err := NewClient(endpoints, SecurityOption{})
c.Assert(err, IsNil)
defer cli.Close()

Expand Down
68 changes: 37 additions & 31 deletions pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ package command

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
Expand All @@ -23,20 +25,52 @@ import (
"net/url"
"os"

log "github.com/Sirupsen/logrus"
"github.com/juju/errors"
"github.com/pingcap/pd/pd-client"
"github.com/spf13/cobra"
)

var (
pdClient pd.Client
dailClient = &http.Client{}

pingPrefix = "pd/ping"
errInvalidAddr = errors.New("Invalid pd address, Cannot get connect to it")
)

// InitHTTPSClient creates https client with ca file
func InitHTTPSClient(CAPath, CertPath, KeyPath string) error {
certificates := []tls.Certificate{}
if len(CertPath) != 0 && len(KeyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(CertPath, KeyPath)
if err != nil {
return errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(CAPath)
if err != nil {
return errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return errors.New("failed to append ca certs")
}

tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: certificates,
RootCAs: certPool,
},
}
dailClient = &http.Client{Transport: tr}

return nil
}

func getRequest(cmd *cobra.Command, prefix string, method string, bodyType string, body io.Reader) (*http.Request, error) {
if method == "" {
method = http.MethodGet
Expand Down Expand Up @@ -82,34 +116,6 @@ func genResponseError(r *http.Response) error {
return errors.Errorf("[%d] %s", r.StatusCode, res)
}

// InitPDClient initialize pd client from cmd
func InitPDClient(cmd *cobra.Command) error {
addr, err := cmd.Flags().GetString("pd")
if err != nil {
return err
}
log.SetOutput(ioutil.Discard)
if pdClient != nil {
return nil
}
err = validPDAddr(addr)
if err != nil {
return err
}
pdClient, err = pd.NewClient([]string{addr})
if err != nil {
return err
}
return nil
}

func getClient() (pd.Client, error) {
if pdClient == nil {
return nil, errors.New("Must initialized pdClient firstly")
}
return pdClient, nil
}

func getAddressFromCmd(cmd *cobra.Command, prefix string) string {
p, err := cmd.Flags().GetString("pd")
if err != nil {
Expand Down
22 changes: 15 additions & 7 deletions pdctl/ctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ package pdctl

import (
"fmt"
"os"

"github.com/pingcap/pd/pdctl/command"
"github.com/spf13/cobra"
)

// CommandFlags are flags that used in all Commands
type CommandFlags struct {
URL string
URL string
CAPath string
CertPath string
KeyPath string
}

var (
Expand All @@ -36,6 +38,9 @@ var (

func init() {
rootCmd.PersistentFlags().StringVarP(&commandFlags.URL, "pd", "u", "http://127.0.0.1:2379", "pd address")
rootCmd.Flags().StringVar(&commandFlags.CAPath, "cacert", "", "path of file that contains list of trusted SSL CAs.")
rootCmd.Flags().StringVar(&commandFlags.CertPath, "cert", "", "path of file that contains X509 certificate in PEM format.")
rootCmd.Flags().StringVar(&commandFlags.KeyPath, "key", "", "path of file that contains X509 key in PEM format.")
rootCmd.AddCommand(
command.NewConfigCommand(),
command.NewRegionCommand(),
Expand All @@ -60,12 +65,15 @@ func Start(args []string) {
rootCmd.SetArgs(args)
rootCmd.SilenceErrors = true
rootCmd.ParseFlags(args)
err := command.InitPDClient(rootCmd)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
rootCmd.SetUsageTemplate(command.UsageTemplate)

if len(commandFlags.CAPath) != 0 {
if err := command.InitHTTPSClient(commandFlags.CAPath, commandFlags.CertPath, commandFlags.KeyPath); err != nil {
fmt.Println(err)
return
}
}

if err := rootCmd.Execute(); err != nil {
fmt.Println(rootCmd.UsageString())
}
Expand Down
7 changes: 5 additions & 2 deletions pkg/etcdutil/etcdutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package etcdutil

import (
"context"
"crypto/tls"
"net/http"
"time"

Expand Down Expand Up @@ -43,7 +44,7 @@ const (

// CheckClusterID checks Etcd's cluster ID, returns an error if mismatch.
// This function will never block even quorum is not satisfied.
func CheckClusterID(localClusterID types.ID, um types.URLsMap) error {
func CheckClusterID(localClusterID types.ID, um types.URLsMap, tlsConfig *tls.Config) error {
if len(um) == 0 {
return nil
}
Expand All @@ -54,7 +55,9 @@ func CheckClusterID(localClusterID types.ID, um types.URLsMap) error {
}

for _, u := range peerURLs {
trp := &http.Transport{}
trp := &http.Transport{
TLSClientConfig: tlsConfig,
}
remoteCluster, gerr := etcdserver.GetClusterFromRemotePeers([]string{u}, trp)
trp.CloseIdleConnections()
if gerr != nil {
Expand Down
Loading