Skip to content

Commit

Permalink
support ssl/tls
Browse files Browse the repository at this point in the history
  • Loading branch information
Connor1996 committed Nov 26, 2017
1 parent 237f977 commit 35a5326
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 49 deletions.
15 changes: 12 additions & 3 deletions cmd/pd-ctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ import (
)

var (
url string
detach bool
version bool
url string
detach bool
version bool
tlsCAPath string
tlsCertPath string
tlsKeyPath string
)

func init() {
flag.StringVarP(&url, "pd", "u", "http://127.0.0.1:2379", "The pd address")
flag.BoolVarP(&detach, "detach", "d", false, "Run pdctl without readline")
flag.BoolVarP(&version, "version", "V", false, "print version information and exit")
flag.StringVar(&tlsCAPath, "ca", "", "path of file that contains list of trusted SSL CAs.")
flag.StringVar(&tlsCertPath, "cert", "", "path of file that contains X509 certificate in PEM format.")
flag.StringVar(&tlsKeyPath, "key", "", "path of file that contains X509 key in PEM format.")
}

func main() {
Expand Down Expand Up @@ -115,6 +121,9 @@ func loop() {
}
args := strings.Split(strings.TrimSpace(line), " ")
args = append(args, "-u", url)
args = append(args, "--ca", tlsCAPath)
args = append(args, "--cert", tlsCertPath)
args = append(args, "--key", tlsKeyPath)
pdctl.Start(args)
}
}
5 changes: 4 additions & 1 deletion cmd/pd-tso-bench/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ var (
concurrency = flag.Int("C", 1000, "concurrency")
sleep = flag.Duration("sleep", time.Millisecond, "sleep time after a request, used to adjust pressure")
interval = flag.Duration("interval", time.Second, "interval to output the statistics")
tlsCAPath = flag.String("ca", "", "path of file that contains list of trusted SSL CAs.")
tlsCertPath = flag.String("cert", "", "path of file that contains X509 certificate in PEM format..")
tlsKeyPath = flag.String("key", "", "path of file that contains X509 key in PEM format.")
wg sync.WaitGroup
)

func main() {
flag.Parse()
pdCli, err := pd.NewClient([]string{*pdAddrs})
pdCli, err := pd.NewClient([]string{*pdAddrs}, *tlsCAPath, *tlsCertPath, *tlsKeyPath)
if err != nil {
log.Fatal(err)
}
Expand Down
14 changes: 14 additions & 0 deletions conf/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ initial-cluster-state = "new"
lease = 3
tso-save-interval = "3s"

# Path of file that contains list of trusted SSL CAs.
tls-ca-path = ""
# Path of file that contains X509 certificate in PEM format.
tls-cert-path = ""
# Path of file that contains X509 key in PEM format.
tls-key-path = ""
# enable client certificate auth, if true following two settings shouldn't be empty
client-cert-auth = false
# Path of file that contains X509 certificate in PEM format for client auth.
# tls-client-cert-path = ""
# Path of file that contains X509 key in PEM format for client auth.
# tls-client-key-path = ""

[log]
level = "info"

Expand Down Expand Up @@ -68,3 +81,4 @@ max-replicas = 3
# For example, ["zone", "rack"] means that we should place replicas to
# different zones first, then to different racks if we don't have enough zones.
location-labels = []

53 changes: 50 additions & 3 deletions pd-client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
package pd

import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"
"strings"
"sync"
"time"
Expand All @@ -25,6 +29,7 @@ import (
"github.com/pingcap/kvproto/pkg/pdpb"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// Client is a PD (Placement Driver) client.
Expand Down Expand Up @@ -93,10 +98,14 @@ type client struct {
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc

tlsCAPath string
tlsCertPath string
tlsKeyPath string
}

// NewClient creates a PD client.
func NewClient(pdAddrs []string) (Client, error) {
func NewClient(pdAddrs []string, tlsCAPath, tlsCertPath, tlsKeyPath string) (Client, error) {
log.Infof("[pd] create pd client with endpoints %v", pdAddrs)
ctx, cancel := context.WithCancel(context.Background())
c := &client{
Expand All @@ -106,6 +115,9 @@ func NewClient(pdAddrs []string) (Client, error) {
checkLeaderCh: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
tlsCAPath: tlsCAPath,
tlsCertPath: tlsCertPath,
tlsKeyPath: tlsKeyPath,
}
c.connMu.clientConns = make(map[string]*grpc.ClientConn)

Expand Down Expand Up @@ -213,11 +225,46 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) {
return conn, nil
}

cc, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) // TODO: Support HTTPS.
opt := grpc.WithInsecure()
if len(c.tlsCAPath) != 0 {

certificates := []tls.Certificate{}
if len(c.tlsCertPath) != 0 && len(c.tlsKeyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(c.tlsCertPath, c.tlsKeyPath)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(c.tlsCAPath)
if err != nil {
return nil, errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}

creds := credentials.NewTLS(&tls.Config{
Certificates: certificates,
RootCAs: certPool,
})

opt = grpc.WithTransportCredentials(creds)
}
u, err := url.Parse(addr)
if err != nil {
return nil, errors.Trace(err)
}
cc, err := grpc.Dial(u.Host, opt)
if err != nil {
return nil, errors.Trace(err)
}

c.connMu.Lock()
defer c.connMu.Unlock()
if old, ok := c.connMu.clientConns[addr]; ok {
Expand Down
68 changes: 37 additions & 31 deletions pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ package command

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
Expand All @@ -23,20 +25,52 @@ import (
"net/url"
"os"

log "github.com/Sirupsen/logrus"
"github.com/juju/errors"
"github.com/pingcap/pd/pd-client"
"github.com/spf13/cobra"
)

var (
pdClient pd.Client
dailClient = &http.Client{}

pingPrefix = "pd/ping"
errInvalidAddr = errors.New("Invalid pd address, Cannot get connect to it")
)

// InitHttpsClient creates https client with ca file
func InitHttpsClient(tlsCAPath, tlsCertPath, tlsKeyPath string) error {
certificates := []tls.Certificate{}
if len(tlsCertPath) != 0 && len(tlsKeyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyPath)
if err != nil {
return errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(tlsCAPath)
if err != nil {
return errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return errors.New("failed to append ca certs")
}

tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: certificates,
RootCAs: certPool,
},
}
dailClient = &http.Client{Transport: tr}

return nil
}

func getRequest(cmd *cobra.Command, prefix string, method string, bodyType string, body io.Reader) (*http.Request, error) {
if method == "" {
method = http.MethodGet
Expand Down Expand Up @@ -82,34 +116,6 @@ func genResponseError(r *http.Response) error {
return errors.Errorf("[%d] %s", r.StatusCode, res)
}

// InitPDClient initialize pd client from cmd
func InitPDClient(cmd *cobra.Command) error {
addr, err := cmd.Flags().GetString("pd")
if err != nil {
return err
}
log.SetOutput(ioutil.Discard)
if pdClient != nil {
return nil
}
err = validPDAddr(addr)
if err != nil {
return err
}
pdClient, err = pd.NewClient([]string{addr})
if err != nil {
return err
}
return nil
}

func getClient() (pd.Client, error) {
if pdClient == nil {
return nil, errors.New("Must initialized pdClient firstly")
}
return pdClient, nil
}

func getAddressFromCmd(cmd *cobra.Command, prefix string) string {
p, err := cmd.Flags().GetString("pd")
if err != nil {
Expand Down
22 changes: 15 additions & 7 deletions pdctl/ctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ package pdctl

import (
"fmt"
"os"

"github.com/pingcap/pd/pdctl/command"
"github.com/spf13/cobra"
)

// CommandFlags are flags that used in all Commands
type CommandFlags struct {
URL string
URL string
tlsCAPath string
tlsCertPath string
tlsKeyPath string
}

var (
Expand All @@ -36,6 +38,9 @@ var (

func init() {
rootCmd.PersistentFlags().StringVarP(&commandFlags.URL, "pd", "u", "http://127.0.0.1:2379", "pd address")
rootCmd.Flags().StringVar(&commandFlags.tlsCAPath, "ca", "", "path of file that contains list of trusted SSL CAs.")
rootCmd.Flags().StringVar(&commandFlags.tlsCertPath, "cert", "", "path of file that contains X509 certificate in PEM format.")
rootCmd.Flags().StringVar(&commandFlags.tlsKeyPath, "key", "", "path of file that contains X509 key in PEM format.")
rootCmd.AddCommand(
command.NewConfigCommand(),
command.NewRegionCommand(),
Expand All @@ -60,12 +65,15 @@ func Start(args []string) {
rootCmd.SetArgs(args)
rootCmd.SilenceErrors = true
rootCmd.ParseFlags(args)
err := command.InitPDClient(rootCmd)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
rootCmd.SetUsageTemplate(command.UsageTemplate)

if len(commandFlags.tlsCAPath) != 0 {
if err := command.InitHttpsClient(commandFlags.tlsCAPath, commandFlags.tlsCertPath, commandFlags.tlsKeyPath); err != nil {
fmt.Println(err)
return
}
}

if err := rootCmd.Execute(); err != nil {
fmt.Println(rootCmd.UsageString())
}
Expand Down
16 changes: 13 additions & 3 deletions server/api/redirector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package api

import (
"crypto/tls"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -67,17 +68,26 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
return
}

newCustomReverseProxies(urls).ServeHTTP(w, r)
tlsConfig, err := h.s.GetTLSConfig()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
newCustomReverseProxies(urls, tlsConfig).ServeHTTP(w, r)
}

type customReverseProxies struct {
urls []url.URL
client *http.Client
}

func newCustomReverseProxies(urls []url.URL) *customReverseProxies {
func newCustomReverseProxies(urls []url.URL, tlsConfig *tls.Config) *customReverseProxies {
p := &customReverseProxies{
client: &http.Client{},
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}

p.urls = append(p.urls, urls...)
Expand Down
Loading

0 comments on commit 35a5326

Please sign in to comment.