diff --git a/arbiter/server.go b/arbiter/server.go index fb8d1b294..8abbcffa6 100644 --- a/arbiter/server.go +++ b/arbiter/server.go @@ -78,7 +78,7 @@ func NewServer(cfg *Config) (srv *Server, err error) { up := cfg.Up down := cfg.Down - srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port) + srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port, nil) if err != nil { return nil, errors.Trace(err) } diff --git a/arbiter/server_test.go b/arbiter/server_test.go index 8b68a5550..a10f502da 100644 --- a/arbiter/server_test.go +++ b/arbiter/server_test.go @@ -15,6 +15,7 @@ package arbiter import ( "context" + "crypto/tls" "database/sql" "fmt" "sync" @@ -55,7 +56,7 @@ func (l *dummyLoader) Close() { type testNewServerSuite struct { db *sql.DB dbMock sqlmock.Sqlmock - origCreateDB func(string, string, string, int) (*sql.DB, error) + origCreateDB func(string, string, string, int, *tls.Config) (*sql.DB, error) origNewReader func(*reader.Config) (*reader.Reader, error) origNewLoader func(*sql.DB, ...loader.Option) (loader.Loader, error) } @@ -71,7 +72,7 @@ func (s *testNewServerSuite) SetUpTest(c *C) { s.dbMock = mock s.origCreateDB = createDB - createDB = func(user string, password string, host string, port int) (*sql.DB, error) { + createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) { return s.db, nil } @@ -105,7 +106,7 @@ func (s *testNewServerSuite) TestRejectInvalidAddr(c *C) { } func (s *testNewServerSuite) TestStopIfFailedtoConnectDownStream(c *C) { - createDB = func(user string, password string, host string, port int) (*sql.DB, error) { + createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) { return nil, fmt.Errorf("Can't create db") } diff --git a/binlogctl/config.go b/binlogctl/config.go index 413a0a4be..cbd2dc3d4 100644 --- a/binlogctl/config.go +++ b/binlogctl/config.go @@ -65,20 +65,20 @@ const ( // Config holds the configuration of drainer type Config struct { - *flag.FlagSet - - Command string `toml:"cmd" json:"cmd"` - NodeID string `toml:"node-id" json:"node-id"` - DataDir string `toml:"data-dir" json:"data-dir"` - TimeZone string `toml:"time-zone" json:"time-zone"` - EtcdURLs string `toml:"pd-urls" json:"pd-urls"` - SSLCA string `toml:"ssl-ca" json:"ssl-ca"` - SSLCert string `toml:"ssl-cert" json:"ssl-cert"` - SSLKey string `toml:"ssl-key" json:"ssl-key"` - State string `toml:"state" json:"state"` - ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"` - Text string `toml:"text" json:"text"` - tls *tls.Config + *flag.FlagSet `toml:"-" json:"-"` + + Command string `toml:"cmd" json:"cmd"` + NodeID string `toml:"node-id" json:"node-id"` + DataDir string `toml:"data-dir" json:"data-dir"` + TimeZone string `toml:"time-zone" json:"time-zone"` + EtcdURLs string `toml:"pd-urls" json:"pd-urls"` + SSLCA string `toml:"ssl-ca" json:"ssl-ca"` + SSLCert string `toml:"ssl-cert" json:"ssl-cert"` + SSLKey string `toml:"ssl-key" json:"ssl-key"` + State string `toml:"state" json:"state"` + ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"` + Text string `toml:"text" json:"text"` + TLS *tls.Config `toml:"-" json:"-"` printVersion bool } @@ -134,7 +134,7 @@ func (cfg *Config) Parse(args []string) error { SSLCert: cfg.SSLCert, SSLKey: cfg.SSLKey, } - cfg.tls, err = sCfg.ToTLSConfig() + cfg.TLS, err = sCfg.ToTLSConfig() if err != nil { return errors.Errorf("tls config error %v", err) } diff --git a/binlogctl/nodes.go b/binlogctl/nodes.go index 5555e3e99..30a1cbca4 100644 --- a/binlogctl/nodes.go +++ b/binlogctl/nodes.go @@ -15,6 +15,7 @@ package binlogctl import ( "context" + "crypto/tls" "fmt" "net/http" "time" @@ -34,8 +35,8 @@ var ( ) // QueryNodesByKind returns specified nodes, like pumps/drainers -func QueryNodesByKind(urls string, kind string, showOffline bool) error { - registry, err := createRegistryFuc(urls) +func QueryNodesByKind(urls string, kind string, showOffline bool, tlsConfig *tls.Config) error { + registry, err := createRegistryFuc(urls, tlsConfig) if err != nil { return errors.Trace(err) } @@ -56,12 +57,12 @@ func QueryNodesByKind(urls string, kind string, showOffline bool) error { } // UpdateNodeState update pump or drainer's state. -func UpdateNodeState(urls, kind, nodeID, state string) error { +func UpdateNodeState(urls, kind, nodeID, state string, tlsConfig *tls.Config) error { /* node's state can be online, pausing, paused, closing and offline. if the state is one of them, will update the node's state saved in etcd directly. */ - registry, err := createRegistryFuc(urls) + registry, err := createRegistryFuc(urls, tlsConfig) if err != nil { return errors.Trace(err) } @@ -81,12 +82,12 @@ func UpdateNodeState(urls, kind, nodeID, state string) error { } // createRegistry returns an ectd registry -func createRegistry(urls string) (*node.EtcdRegistry, error) { +func createRegistry(urls string, tlsConfig *tls.Config) (*node.EtcdRegistry, error) { ectdEndpoints, err := flags.ParseHostPortAddr(urls) if err != nil { return nil, errors.Trace(err) } - cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, nil) + cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, tlsConfig) if err != nil { return nil, errors.Trace(err) } @@ -95,8 +96,8 @@ func createRegistry(urls string) (*node.EtcdRegistry, error) { } // ApplyAction applies action on pump or drainer -func ApplyAction(urls, kind, nodeID string, action string) error { - registry, err := createRegistryFuc(urls) +func ApplyAction(urls, kind, nodeID string, action string, tlsConfig *tls.Config) error { + registry, err := createRegistryFuc(urls, tlsConfig) if err != nil { return errors.Trace(err) } @@ -106,14 +107,18 @@ func ApplyAction(urls, kind, nodeID string, action string) error { return errors.Trace(err) } - var client http.Client - url := fmt.Sprintf("http://%s/state/%s/%s", n.Addr, n.NodeID, action) + schema := "http" + if tlsConfig != nil { + schema = "https" + } + + url := fmt.Sprintf("%s://%s/state/%s/%s", schema, n.Addr, n.NodeID, action) log.Debug("send put http request", zap.String("url", url)) req, err := http.NewRequest("PUT", url, nil) if err != nil { return errors.Trace(err) } - _, err = client.Do(req) + _, err = getClient(tlsConfig).Do(req) if err == nil { log.Info("Apply action on node success", zap.String("action", action), zap.String("NodeID", n.NodeID)) return nil @@ -121,3 +126,15 @@ func ApplyAction(urls, kind, nodeID string, action string) error { return errors.Trace(err) } + +func getClient(tlsConfig *tls.Config) *http.Client { + if tlsConfig == nil { + return &http.Client{} + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } +} diff --git a/binlogctl/nodes_test.go b/binlogctl/nodes_test.go index ef851f033..0f518e9ae 100644 --- a/binlogctl/nodes_test.go +++ b/binlogctl/nodes_test.go @@ -48,7 +48,7 @@ type testNodesSuite struct{} func (s *testNodesSuite) SetUpTest(c *C) { newEtcdClientFromCfgFunc = newFakeEtcdClientFromCfg createRegistryFuc = createMockRegistry - _, err := createMockRegistry("127.0.0.1:2379") + _, err := createMockRegistry("127.0.0.1:2379", nil) c.Assert(err, IsNil) } @@ -63,11 +63,11 @@ func (s *testNodesSuite) TestApplyAction(c *C) { registerPumpForTest(c, "test", url) - err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump) + err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump, nil) c.Assert(errors.IsNotFound(err), IsTrue) // TODO: handle log information and add check - err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump) + err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump, nil) c.Assert(err, IsNil) } @@ -75,17 +75,17 @@ func (s *testNodesSuite) TestQueryNodesByKind(c *C) { registerPumpForTest(c, "test", "127.0.0.1:8255") // TODO: handle log information and add check - err := QueryNodesByKind("127.0.0.1:2379", "pumps", false) + err := QueryNodesByKind("127.0.0.1:2379", "pumps", false, nil) c.Assert(err, IsNil) } func (s *testNodesSuite) TestUpdateNodeState(c *C) { registerPumpForTest(c, "test", "127.0.0.1:8255") - err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused) + err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused, nil) c.Assert(err, ErrorMatches, ".*not found.*") - err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused) + err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused, nil) c.Assert(err, IsNil) // check node's state is changed to paused @@ -104,7 +104,7 @@ func (s *testNodesSuite) TestUpdateNodeState(c *C) { func (s *testNodesSuite) TestCreateRegistry(c *C) { urls := "127.0.0.1:2379" - registry, err := createRegistry(urls) + registry, err := createRegistry(urls, nil) c.Assert(err, IsNil) c.Assert(registry, NotNil) @@ -131,7 +131,7 @@ func (s *testNodesSuite) TestCreateRegistry(c *C) { } -func createMockRegistry(urls string) (*node.EtcdRegistry, error) { +func createMockRegistry(urls string, _ *tls.Config) (*node.EtcdRegistry, error) { if fakeRegistry != nil { return fakeRegistry, nil } diff --git a/cmd/binlogctl/main.go b/cmd/binlogctl/main.go index 7f56473fd..b8c72d932 100644 --- a/cmd/binlogctl/main.go +++ b/cmd/binlogctl/main.go @@ -45,21 +45,21 @@ func main() { case ctl.GenerateMeta: err = ctl.GenerateMetaInfo(cfg) case ctl.QueryPumps: - err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes) + err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes, cfg.TLS) case ctl.QueryDrainers: - err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes) + err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes, cfg.TLS) case ctl.UpdatePump: - err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State) + err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State, cfg.TLS) case ctl.UpdateDrainer: - err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State) + err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State, cfg.TLS) case ctl.PausePump: - err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause) + err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause, cfg.TLS) case ctl.PauseDrainer: - err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause) + err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause, cfg.TLS) case ctl.OfflinePump: - err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close) + err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close, cfg.TLS) case ctl.OfflineDrainer: - err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close) + err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close, cfg.TLS) case ctl.Encrypt: if len(cfg.Text) == 0 { err = errors.New("need to specify the text to be encrypt") diff --git a/cmd/drainer/drainer.toml b/cmd/drainer/drainer.toml index b52512233..c0da90a3e 100644 --- a/cmd/drainer/drainer.toml +++ b/cmd/drainer/drainer.toml @@ -108,6 +108,16 @@ port = 3306 # when setting SyncPartialColumn drainer will allow the downstream schema # having more or less column numbers and relax sql mode by removing STRICT_TRANS_TABLES. # sync-mode = 1 +# +# Uncomment this part if you need TLS to connecting downstream MySQL/TiDB. +# You can only specified only `ssl-ca` if there is no client certificate and don't need server to authenticate client. +# [syncer.to.security] +# Path of file that contains list of trusted SSL CAs. +# ssl-ca = "/path/to/ca.pem" +# Path of file that contains X509 certificate in PEM format. +# ssl-cert = "/path/to/drainer.pem" +# Path of file that contains X509 key in PEM format. +# ssl-key = "/path/to/drainer-key.pem" [syncer.to.checkpoint] # only support mysql or tidb now, you can uncomment this to control where the checkpoint is saved. diff --git a/cmd/pump/main.go b/cmd/pump/main.go index 1493a244b..ecbccd0c8 100644 --- a/cmd/pump/main.go +++ b/cmd/pump/main.go @@ -45,6 +45,7 @@ func main() { log.Fatal("Failed to initialize log", zap.Error(err)) } version.PrintVersionInfo("Pump") + log.Info("start pump...", zap.Reflect("config", cfg)) p, err := pump.NewServer(cfg) if err != nil { diff --git a/drainer/collector.go b/drainer/collector.go index bca6c161d..778264f39 100644 --- a/drainer/collector.go +++ b/drainer/collector.go @@ -14,6 +14,7 @@ package drainer import ( + "crypto/tls" "fmt" "net/http" "strings" @@ -49,6 +50,7 @@ type notifyResult struct { // Collector collects binlog from all pump, and send binlog to syncer. type Collector struct { clusterID uint64 + tls *tls.Config interval time.Duration reg *node.EtcdRegistry tiStore kv.Storage @@ -106,6 +108,7 @@ func NewCollector(cfg *Config, clusterID uint64, s *Syncer, cpt checkpoint.Check c := &Collector{ clusterID: clusterID, + tls: cfg.tls, interval: time.Duration(cfg.DetectInterval) * time.Second, reg: node.NewEtcdRegistry(cli, cfg.EtcdTimeout), pumps: make(map[string]*Pump), @@ -308,7 +311,7 @@ func (c *Collector) handlePumpStatusUpdate(ctx context.Context, n *node.Status) } commitTS := c.merger.GetLatestTS() - p := NewPump(n.NodeID, n.Addr, c.clusterID, commitTS, c.errCh) + p := NewPump(n.NodeID, n.Addr, c.tls, c.clusterID, commitTS, c.errCh) c.pumps[n.NodeID] = p c.merger.AddSource(MergeSource{ ID: n.NodeID, diff --git a/drainer/config.go b/drainer/config.go index 50f4b93d6..4b473726e 100644 --- a/drainer/config.go +++ b/drainer/config.go @@ -216,6 +216,13 @@ func (cfg *Config) Parse(args []string) error { return errors.Errorf("tls config %+v error %v", cfg.Security, err) } + if cfg.SyncerCfg != nil && cfg.SyncerCfg.To != nil { + cfg.SyncerCfg.To.TLS, err = cfg.SyncerCfg.To.Security.ToTLSConfig() + if err != nil { + return errors.Errorf("tls config %+v error %v", cfg.SyncerCfg.To.Security, err) + } + } + if err = cfg.adjustConfig(); err != nil { return errors.Trace(err) } diff --git a/drainer/pump.go b/drainer/pump.go index cde8d413b..c3c0b41e5 100644 --- a/drainer/pump.go +++ b/drainer/pump.go @@ -14,6 +14,7 @@ package drainer import ( + "crypto/tls" "strings" "sync/atomic" "time" @@ -29,6 +30,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/status" ) @@ -40,6 +42,7 @@ const ( type Pump struct { nodeID string addr string + tlsConfig *tls.Config clusterID uint64 // the latest binlog ts that pump had handled latestTS int64 @@ -56,11 +59,12 @@ type Pump struct { } // NewPump returns an instance of Pump -func NewPump(nodeID, addr string, clusterID uint64, startTs int64, errCh chan error) *Pump { +func NewPump(nodeID, addr string, tlsConfig *tls.Config, clusterID uint64, startTs int64, errCh chan error) *Pump { nodeID = pump.FormatNodeID(nodeID) return &Pump{ nodeID: nodeID, addr: addr, + tlsConfig: tlsConfig, clusterID: clusterID, latestTS: startTs, errCh: errCh, @@ -204,7 +208,14 @@ func (p *Pump) createPullBinlogsClient(ctx context.Context, last int64) error { callOpts = append(callOpts, grpc.UseCompressor(compressor)) } - conn, err := grpc.Dial(p.addr, grpc.WithInsecure(), grpc.WithDefaultCallOptions(callOpts...)) + dialOpts := []grpc.DialOption{grpc.WithDefaultCallOptions(callOpts...)} + if p.tlsConfig != nil { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(p.tlsConfig))) + } else { + dialOpts = append(dialOpts, grpc.WithInsecure()) + } + + conn, err := grpc.Dial(p.addr, dialOpts...) if err != nil { p.logger.Error("pump create grpc dial failed", zap.Error(err)) p.pullCli = nil diff --git a/drainer/pump_test.go b/drainer/pump_test.go index 56e56eeb7..d03656f65 100644 --- a/drainer/pump_test.go +++ b/drainer/pump_test.go @@ -61,7 +61,7 @@ func (x *mockPumpPullBinlogsClient) Recv() (*binlog.PullBinlogResp, error) { func (s *pumpSuite) TestPullBinlog(c *C) { errChan := make(chan error, 10) - p := NewPump("pump_test", "", 0, 5, errChan) + p := NewPump("pump_test", "", nil, 0, 5, errChan) p.grpcConn = &grpc.ClientConn{} binlogBytesChan := make(chan []byte, 10) p.pullCli = &mockPumpPullBinlogsClient{binlogBytesChan: binlogBytesChan} diff --git a/drainer/relay.go b/drainer/relay.go index ad0b8626c..b33452bd7 100644 --- a/drainer/relay.go +++ b/drainer/relay.go @@ -43,7 +43,7 @@ func feedByRelayLogIfNeed(cfg *Config) error { return errors.Annotate(err, "failed to create reader") } - db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.StrSQLMode) + db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.To.TLS, scfg.StrSQLMode) if err != nil { return errors.Annotate(err, "failed to create SQL db") } diff --git a/drainer/server.go b/drainer/server.go index 4549e5884..ec88157ee 100644 --- a/drainer/server.go +++ b/drainer/server.go @@ -15,7 +15,6 @@ package drainer import ( "fmt" - "net" "net/http" "net/url" "os" @@ -98,6 +97,11 @@ func NewServer(cfg *Config) (*Server, error) { return nil, err } + if cfg.tls != nil { + // TODO: avoid this magic enabling TLS for tikv client. + var _ = cfg.Security.ToTiDBSecurityConfig() + } + // get pd client and cluster ID pdCli, err := getPdClient(cfg.EtcdURLs, cfg.Security) if err != nil { @@ -284,15 +288,12 @@ func (s *Server) Start() error { } }) - // start a TCP listener - tcpURL, err := url.Parse(s.tcpAddr) - if err != nil { - return errors.Annotatef(err, "invalid listening tcp addr (%s)", s.tcpAddr) - } - tcpLis, err := net.Listen("tcp", tcpURL.Host) + // We need to manage TLS here for cmux to distinguish between HTTP and gRPC. + tcpLis, err := util.Listen("tcp", s.tcpAddr, s.cfg.tls) if err != nil { - return errors.Annotatef(err, "fail to start TCP listener on %s", tcpURL.Host) + return errors.Trace(err) } + m := cmux.New(tcpLis) grpcL := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) httpL := m.Match(cmux.HTTP1Fast()) diff --git a/drainer/sync/mysql.go b/drainer/sync/mysql.go index c957f7506..bd2de8d04 100644 --- a/drainer/sync/mysql.go +++ b/drainer/sync/mysql.go @@ -88,7 +88,11 @@ func NewMysqlSyncer( relayer relay.Relayer, info *loopbacksync.LoopBackSync, ) (*MysqlSyncer, error) { - db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, sqlMode) + if cfg.TLS != nil { + log.Info("enable TLS to connect downstream MySQL/TiDB") + } + + db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, sqlMode) if err != nil { return nil, errors.Trace(err) } @@ -104,7 +108,7 @@ func NewMysqlSyncer( if newMode != oldMode { db.Close() - db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, &newMode) + db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, &newMode) if err != nil { return nil, errors.Trace(err) } diff --git a/drainer/sync/syncer_test.go b/drainer/sync/syncer_test.go index 7449fb737..d02dd9388 100644 --- a/drainer/sync/syncer_test.go +++ b/drainer/sync/syncer_test.go @@ -13,6 +13,7 @@ package sync import ( + "crypto/tls" "database/sql" "reflect" "sync/atomic" @@ -57,7 +58,7 @@ func (s *syncerSuite) SetUpTest(c *check.C) { // create mysql syncer oldCreateDB := createDB - createDB = func(string, string, string, int, *string) (db *sql.DB, err error) { + createDB = func(string, string, string, int, *tls.Config, *string) (db *sql.DB, err error) { db, s.mysqlMock, err = sqlmock.New() return } diff --git a/drainer/sync/util.go b/drainer/sync/util.go index 8afcb7842..7c2892ee1 100644 --- a/drainer/sync/util.go +++ b/drainer/sync/util.go @@ -14,15 +14,20 @@ package sync import ( + "crypto/tls" + // mysql driver _ "github.com/go-sql-driver/mysql" + "github.com/pingcap/tidb-binlog/pkg/security" ) // DBConfig is the DB configuration. type DBConfig struct { - Host string `toml:"host" json:"host"` - User string `toml:"user" json:"user"` - Password string `toml:"password" json:"password"` + Host string `toml:"host" json:"host"` + User string `toml:"user" json:"user"` + Password string `toml:"password" json:"password"` + Security security.Config `toml:"security" json:"security"` + TLS *tls.Config `toml:"-" json:"-"` // if EncryptedPassword is not empty, Password will be ignore. EncryptedPassword string `toml:"encrypted_password" json:"encrypted_password"` SyncMode int `toml:"sync-mode" json:"sync-mode"` diff --git a/pkg/loader/example_loader_test.go b/pkg/loader/example_loader_test.go index a545bc326..0f76ad5bd 100644 --- a/pkg/loader/example_loader_test.go +++ b/pkg/loader/example_loader_test.go @@ -17,7 +17,7 @@ import "log" func Example() { // create sql.DB - db, err := CreateDB("root", "", "localhost", 4000) + db, err := CreateDB("root", "", "localhost", 4000, nil /* *tls.Config */) if err != nil { log.Fatal(err) } diff --git a/pkg/loader/util.go b/pkg/loader/util.go index 996312b12..20ef3dd0f 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -14,12 +14,16 @@ package loader import ( + "crypto/tls" gosql "database/sql" "fmt" "hash/crc32" "net/url" + "strconv" "strings" + "sync/atomic" + "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" ) @@ -77,14 +81,25 @@ func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, e return } +var customID int64 + // CreateDBWithSQLMode return sql.DB -func CreateDBWithSQLMode(user string, password string, host string, port int, sqlMode *string) (db *gosql.DB, err error) { +func CreateDBWithSQLMode(user string, password string, host string, port int, tlsConfig *tls.Config, sqlMode *string) (db *gosql.DB, err error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port) if sqlMode != nil { // same as "set sql_mode = ''" dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'" } + if tlsConfig != nil { + name := "custom_" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10) + err := mysql.RegisterTLSConfig(name, tlsConfig) + if err != nil { + return nil, errors.Annotate(err, "failed to RegisterTLSConfig") + } + dsn += "&tls=" + name + } + db, err = gosql.Open("mysql", dsn) if err != nil { return nil, errors.Trace(err) @@ -93,8 +108,8 @@ func CreateDBWithSQLMode(user string, password string, host string, port int, sq } // CreateDB return sql.DB -func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) { - return CreateDBWithSQLMode(user, password, host, port, nil) +func CreateDB(user string, password string, host string, port int, tls *tls.Config) (db *gosql.DB, err error) { + return CreateDBWithSQLMode(user, password, host, port, tls, nil) } func quoteSchema(schema string, table string) string { diff --git a/pkg/security/security.go b/pkg/security/security.go index 71089f3e9..7be3d6f1c 100644 --- a/pkg/security/security.go +++ b/pkg/security/security.go @@ -30,38 +30,52 @@ type Config struct { } // ToTLSConfig generates tls's config based on security section of the config. -func (c *Config) ToTLSConfig() (*tls.Config, error) { - var tlsConfig *tls.Config - if len(c.SSLCA) != 0 { - var certificates = make([]tls.Certificate, 0) - if len(c.SSLCert) != 0 && len(c.SSLKey) != 0 { +func (c *Config) ToTLSConfig() (tlsConfig *tls.Config, err error) { + if c.SSLCA == "" { + return + } + + // Create a certificate pool from the certificate authority + certPool := x509.NewCertPool() + var ca []byte + ca, err = ioutil.ReadFile(c.SSLCA) + if err != nil { + return nil, errors.Errorf("could not read ca certificate: %s", err) + } + + // Append the certificates from the CA + if !certPool.AppendCertsFromPEM(ca) { + return nil, errors.New("failed to append ca certs") + } + + tlsConfig = &tls.Config{ + RootCAs: certPool, + } + + if len(c.SSLCert) != 0 && len(c.SSLKey) != 0 { + getCert := func() (*tls.Certificate, error) { // Load the client certificates from disk - certificate, err := tls.LoadX509KeyPair(c.SSLCert, c.SSLKey) + cert, err := tls.LoadX509KeyPair(c.SSLCert, c.SSLKey) if err != nil { return nil, errors.Errorf("could not load client key pair: %s", err) } - certificates = append(certificates, certificate) + return &cert, nil } - // Create a certificate pool from the certificate authority - certPool := x509.NewCertPool() - ca, err := ioutil.ReadFile(c.SSLCA) - if err != nil { - return nil, errors.Errorf("could not read ca certificate: %s", err) + // pre-test cert's loading. + if _, err = getCert(); err != nil { + return } - // Append the certificates from the CA - if !certPool.AppendCertsFromPEM(ca) { - return nil, errors.New("failed to append ca certs") + tlsConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, err error) { + return getCert() } - - tlsConfig = &tls.Config{ - Certificates: certificates, - RootCAs: certPool, + tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) { + return getCert() } } - return tlsConfig, nil + return } // ToTiDBSecurityConfig generates tidb security config @@ -72,6 +86,8 @@ func (c *Config) ToTiDBSecurityConfig() config.Security { ClusterSSLKey: c.SSLKey, } + // The TiKV client(kvstore.New) we use will use this global var as the TLS config. + // TODO avoid such magic implicit change when call this func. config.GetGlobalConfig().Security = security return security } diff --git a/pkg/security/security_test.go b/pkg/security/security_test.go index 35ecb7b04..59539b6ac 100644 --- a/pkg/security/security_test.go +++ b/pkg/security/security_test.go @@ -29,6 +29,46 @@ func TestClient(t *testing.T) { TestingT(t) } +// These certs are generated with: +// +// ```sh +// # generate CA keys +// openssl ecparam -name secp224r1 -genkey -noout -out ca.key +// openssl req -x509 -new -nodes -key ca.key -days 999999 -out ca.crt -subj '/CN=localhost' +// +// # generate SSL keys +// openssl ecparam -name secp224r1 -genkey -noout -out ssl.key +// openssl req -new -key ssl.key -out ssl.csr -subj '/CN=localhost' +// openssl x509 -req -in ssl.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out ssl.crt -days 999999 +// ``` +var testCa = ` +-----BEGIN CERTIFICATE----- +MIIBBjCBtQIJAMLMVjQw2v1pMAoGCCqGSM49BAMCMBQxEjAQBgNVBAMMCWxvY2Fs +aG9zdDAgFw0xOTA0MTcxODEyNDNaGA80NzU3MDMxMzE4MTI0M1owFDESMBAGA1UE +AwwJbG9jYWxob3N0ME4wEAYHKoZIzj0CAQYFK4EEACEDOgAELwEHdmAcDtBYK9BH +72q0dKbBBqIG7MZ5+qc+LTcz0OSdhuWkWUZkNN6MqKAPuP7nSo1+21Vb8YswCgYI +KoZIzj0EAwIDQAAwPQIcbNvV16rOOzwotH65cJY6cCdf0h3IODjlWMf1qAIdAIBB +Fma6g8iW5zdQPqDR9BGqugNPjtI/SMK6tfQ= +-----END CERTIFICATE----- +` +var testCert = ` +-----BEGIN CERTIFICATE----- +MIIBBTCBtAIJAP8wfS+6tJ3LMAkGByqGSM49BAEwFDESMBAGA1UEAwwJbG9jYWxo +b3N0MCAXDTE5MDQxNzE4MTI0NFoYDzQ3NTcwMzEzMTgxMjQ0WjAUMRIwEAYDVQQD +DAlsb2NhbGhvc3QwTjAQBgcqhkjOPQIBBgUrgQQAIQM6AAQJaXEnDhG2tPxD4wl1 +ycaZwqWm9JeQZFuUPgxekGwCMM22sKpYLvhdKroSBoKWwXIC6vZMWeIj/zAJBgcq +hkjOPQQBA0EAMD4CHQC05dXi9zFLjYjQGhpJNx+Nc/5vC6E7j/MU+xsTAh0A6SUn +g916djuFWv8djdDq+0NEFD9OzgPdSb8rZw== +-----END CERTIFICATE----- +` +var testKey = ` +-----BEGIN EC PRIVATE KEY----- +MGgCAQEEHCsPBVueZ3YX3yp1tn15YXj0cTKGCo1SO1EWO92gBwYFK4EEACGhPAM6 +AAQJaXEnDhG2tPxD4wl1ycaZwqWm9JeQZFuUPgxekGwCMM22sKpYLvhdKroSBoKW +wXIC6vZMWeIj/w== +-----END EC PRIVATE KEY----- +` + var _ = Suite(&testSecuritySuite{}) type testSecuritySuite struct{} @@ -62,60 +102,32 @@ func (s *testSecuritySuite) TestToTLSConfig(c *C) { SSLKey: filepath.Join(temp, "ssl.key"), } - // These certs are generated with: - // - // ```sh - // # generate CA keys - // openssl ecparam -name secp224r1 -genkey -noout -out ca.key - // openssl req -x509 -new -nodes -key ca.key -days 999999 -out ca.crt -subj '/CN=localhost' - // - // # generate SSL keys - // openssl ecparam -name secp224r1 -genkey -noout -out ssl.key - // openssl req -new -key ssl.key -out ssl.csr -subj '/CN=localhost' - // openssl x509 -req -in ssl.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out ssl.crt -days 999999 - // ``` - - err := ioutil.WriteFile(dummyConfig.SSLCA, []byte(` ------BEGIN CERTIFICATE----- -MIIBBjCBtQIJAMLMVjQw2v1pMAoGCCqGSM49BAMCMBQxEjAQBgNVBAMMCWxvY2Fs -aG9zdDAgFw0xOTA0MTcxODEyNDNaGA80NzU3MDMxMzE4MTI0M1owFDESMBAGA1UE -AwwJbG9jYWxob3N0ME4wEAYHKoZIzj0CAQYFK4EEACEDOgAELwEHdmAcDtBYK9BH -72q0dKbBBqIG7MZ5+qc+LTcz0OSdhuWkWUZkNN6MqKAPuP7nSo1+21Vb8YswCgYI -KoZIzj0EAwIDQAAwPQIcbNvV16rOOzwotH65cJY6cCdf0h3IODjlWMf1qAIdAIBB -Fma6g8iW5zdQPqDR9BGqugNPjtI/SMK6tfQ= ------END CERTIFICATE----- - `), 0644) + err := ioutil.WriteFile(dummyConfig.SSLCA, []byte(testCa), 0644) c.Assert(err, IsNil) - err = ioutil.WriteFile(dummyConfig.SSLCert, []byte(` ------BEGIN CERTIFICATE----- -MIIBBTCBtAIJAP8wfS+6tJ3LMAkGByqGSM49BAEwFDESMBAGA1UEAwwJbG9jYWxo -b3N0MCAXDTE5MDQxNzE4MTI0NFoYDzQ3NTcwMzEzMTgxMjQ0WjAUMRIwEAYDVQQD -DAlsb2NhbGhvc3QwTjAQBgcqhkjOPQIBBgUrgQQAIQM6AAQJaXEnDhG2tPxD4wl1 -ycaZwqWm9JeQZFuUPgxekGwCMM22sKpYLvhdKroSBoKWwXIC6vZMWeIj/zAJBgcq -hkjOPQQBA0EAMD4CHQC05dXi9zFLjYjQGhpJNx+Nc/5vC6E7j/MU+xsTAh0A6SUn -g916djuFWv8djdDq+0NEFD9OzgPdSb8rZw== ------END CERTIFICATE----- - `), 0644) + err = ioutil.WriteFile(dummyConfig.SSLCert, []byte(testCert), 0644) c.Assert(err, IsNil) - err = ioutil.WriteFile(dummyConfig.SSLKey, []byte(` ------BEGIN EC PRIVATE KEY----- -MGgCAQEEHCsPBVueZ3YX3yp1tn15YXj0cTKGCo1SO1EWO92gBwYFK4EEACGhPAM6 -AAQJaXEnDhG2tPxD4wl1ycaZwqWm9JeQZFuUPgxekGwCMM22sKpYLvhdKroSBoKW -wXIC6vZMWeIj/w== ------END EC PRIVATE KEY----- - `), 0600) + err = ioutil.WriteFile(dummyConfig.SSLKey, []byte(testKey), 0600) c.Assert(err, IsNil) config, err := dummyConfig.ToTLSConfig() c.Assert(err, IsNil) c.Assert(config, NotNil) c.Assert(config.RootCAs.Subjects(), HasLen, 1) - c.Assert(config.Certificates, HasLen, 1) - sslKey, ok := config.Certificates[0].PrivateKey.(*ecdsa.PrivateKey) + + cert, err := config.GetCertificate(nil) + c.Assert(err, IsNil) + sslKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + c.Assert(ok, IsTrue) + c.Assert(sslKey.Curve, Equals, elliptic.P224()) + + cert, err = config.GetClientCertificate(nil) + c.Assert(err, IsNil) + sslKey, ok = cert.PrivateKey.(*ecdsa.PrivateKey) c.Assert(ok, IsTrue) c.Assert(sslKey.Curve, Equals, elliptic.P224()) + } func (s *testSecuritySuite) TestEmptyTLSConfig(c *C) { @@ -148,7 +160,17 @@ func (s *testSecuritySuite) TestInvalidTLSConfig(c *C) { err = ioutil.WriteFile(dummyConfig.SSLKey, []byte("invalid key"), 0600) c.Assert(err, IsNil) + // make ca valid. + err = ioutil.WriteFile(dummyConfig.SSLCA, []byte(testCa), 0644) + c.Assert(err, IsNil) _, err = dummyConfig.ToTLSConfig() c.Assert(err, ErrorMatches, "could not load client key pair.*") + // make cert/key valid can check again. + err = ioutil.WriteFile(dummyConfig.SSLCert, []byte(testCert), 0644) + c.Assert(err, IsNil) + err = ioutil.WriteFile(dummyConfig.SSLKey, []byte(testKey), 0600) + c.Assert(err, IsNil) + _, err = dummyConfig.ToTLSConfig() + c.Assert(err, IsNil) } diff --git a/pkg/util/net.go b/pkg/util/net.go new file mode 100644 index 000000000..0607a0db5 --- /dev/null +++ b/pkg/util/net.go @@ -0,0 +1,44 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "crypto/tls" + "net" + "net/url" + + "github.com/pingcap/errors" +) + +// Listen return the listener from tls.Listen if tlsConfig is NOT Nil. +func Listen(network, addr string, tlsConfig *tls.Config) (listener net.Listener, err error) { + URL, err := url.Parse(addr) + if err != nil { + return nil, errors.Annotatef(err, "invalid listening socket addr (%s)", addr) + } + + if tlsConfig != nil { + listener, err = tls.Listen(network, URL.Host, tlsConfig) + if err != nil { + return nil, errors.Annotatef(err, "fail to start %s on %s", network, URL.Host) + } + } else { + listener, err = net.Listen(network, URL.Host) + if err != nil { + return nil, errors.Annotatef(err, "fail to start %s on %s", network, URL.Host) + } + } + + return listener, nil +} diff --git a/pkg/util/net_test.go b/pkg/util/net_test.go new file mode 100644 index 000000000..6aa151324 --- /dev/null +++ b/pkg/util/net_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "github.com/pingcap/check" +) + +type netSuite struct{} + +var _ = check.Suite(&netSuite{}) + +func (n *netSuite) TestListen(c *check.C) { + // wrong addr + _, err := Listen("unix", "://asdf:1231:123:12", nil) + c.Assert(err, check.ErrorMatches, ".*invalid .* socket addr.*") + + // unbindable addr + _, err = Listen("tcp", "http://asdf;klj:7979/12", nil) + c.Assert(err, check.ErrorMatches, ".*fail to start.*") + + // return listener + l, err := Listen("tcp", "http://localhost:17979", nil) + c.Assert(err, check.IsNil) + c.Assert(l, check.NotNil) +} diff --git a/pump/config.go b/pump/config.go index b9f29fed3..cf7f7894a 100644 --- a/pump/config.go +++ b/pump/config.go @@ -54,7 +54,7 @@ type globalConfig struct { // Config holds the configuration of pump type Config struct { - *flag.FlagSet + *flag.FlagSet `json:"-"` LogLevel string `toml:"log-level" json:"log-level"` NodeID string `toml:"node-id" json:"node-id"` ListenAddr string `toml:"addr" json:"addr"` diff --git a/pump/server.go b/pump/server.go index 1bfa328a5..43689495c 100644 --- a/pump/server.go +++ b/pump/server.go @@ -19,7 +19,6 @@ import ( "math" "net" "net/http" - "net/url" "strconv" "strings" "sync" @@ -47,7 +46,6 @@ import ( "go.uber.org/zap" "golang.org/x/net/context" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) var ( @@ -128,9 +126,6 @@ func NewServer(cfg *Config) (*Server, error) { log.Info("get clusterID success", zap.Uint64("clusterID", clusterID)) grpcOpts := []grpc.ServerOption{grpc.MaxRecvMsgSize(GlobalConfig.maxMsgSize)} - if cfg.tls != nil { - grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(cfg.tls))) - } urlv, err := flags.NewURLsValue(cfg.EtcdURLs) if err != nil { @@ -336,12 +331,11 @@ func (s *Server) startHeartbeat() { // Start runs Pump Server to serve the listening addr, and maintains heartbeat to Etcd func (s *Server) Start() error { - // start a UNIX listener var unixLis net.Listener var err error if s.unixAddr != "" { - unixLis, err = listen("unix", s.unixAddr) + unixLis, err = util.Listen("unix", s.unixAddr, s.cfg.tls) if err != nil { return errors.Trace(err) } @@ -350,7 +344,8 @@ func (s *Server) Start() error { log.Debug("init success") // start a TCP listener - tcpLis, err := listen("tcp", s.tcpAddr) + // we need to manage TLS here for cmux to distinguish between HTTP and gRPC. + tcpLis, err := util.Listen("tcp", s.tcpAddr, s.cfg.tls) if err != nil { return errors.Trace(err) } @@ -388,7 +383,12 @@ func (s *Server) Start() error { grpcL := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) httpL := m.Match(cmux.HTTP1Fast()) - go s.gs.Serve(grpcL) + + go func() { + if err := s.gs.Serve(grpcL); err != nil { + log.Error("Unexpected exit of gRPC server", zap.Error(err)) + } + }() router := mux.NewRouter() router.HandleFunc("/status", s.Status).Methods("GET") @@ -923,15 +923,3 @@ func (s *Server) waitUntilCommitTSSaved(ctx context.Context, ts int64, checkInte } } } - -func listen(network, addr string) (net.Listener, error) { - URL, err := url.Parse(addr) - if err != nil { - return nil, errors.Annotatef(err, "invalid listening socket addr (%s)", addr) - } - listener, err := net.Listen(network, URL.Host) - if err != nil { - return nil, errors.Annotatef(err, "fail to start %s on %s", network, URL.Host) - } - return listener, nil -} diff --git a/pump/server_test.go b/pump/server_test.go index 0957f3cf6..8e7d41c7b 100644 --- a/pump/server_test.go +++ b/pump/server_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io/ioutil" "math/rand" - "net" "net/http" "path" "strconv" @@ -544,28 +543,6 @@ func (s *waitCommitTSSuite) TestShouldWaitUntilTs(c *C) { } } -type listenSuite struct{} - -var _ = Suite(&listenSuite{}) - -func (s *listenSuite) TestWrongAddr(c *C) { - _, err := listen("unix", "://asdf:1231:123:12") - c.Assert(err, ErrorMatches, ".*invalid .* socket addr.*") -} - -func (s *listenSuite) TestUnbindableAddr(c *C) { - _, err := listen("tcp", "http://asdf;klj:7979/12") - c.Assert(err, ErrorMatches, ".*fail to start.*") -} - -func (s *listenSuite) TestReturnListener(c *C) { - var l net.Listener - l, err := listen("tcp", "http://localhost:17979") - c.Assert(err, IsNil) - defer l.Close() - c.Assert(l, NotNil) -} - type mockPdCli struct { pd.Client } diff --git a/reparo/syncer/mysql.go b/reparo/syncer/mysql.go index b023e905a..f318eef96 100644 --- a/reparo/syncer/mysql.go +++ b/reparo/syncer/mysql.go @@ -50,7 +50,7 @@ var ( var createDB = loader.CreateDB func newMysqlSyncer(cfg *DBConfig, worker int, batchSize int, safemode bool) (*mysqlSyncer, error) { - db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port) + db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, nil) if err != nil { return nil, errors.Trace(err) } diff --git a/reparo/syncer/mysql_test.go b/reparo/syncer/mysql_test.go index 963bc8c41..3c93d87ea 100644 --- a/reparo/syncer/mysql_test.go +++ b/reparo/syncer/mysql_test.go @@ -1,6 +1,7 @@ package syncer import ( + "crypto/tls" "database/sql" "time" @@ -24,7 +25,7 @@ func (s *testMysqlSuite) testMysqlSyncer(c *check.C, safemode bool) { ) oldCreateDB := createDB - createDB = func(string, string, string, int) (db *sql.DB, err error) { + createDB = func(string, string, string, int, *tls.Config) (db *sql.DB, err error) { db, mock, err = sqlmock.New() return } diff --git a/tests/_utils/check_status b/tests/_utils/check_status index dfcb4e070..9fde6dd58 100755 --- a/tests/_utils/check_status +++ b/tests/_utils/check_status @@ -4,11 +4,11 @@ set -eu NODE_KIND=$1 if [ $NODE_KIND == "pumps" ]; then - NODE_ID=$2 - STATE=$3 + NODE_ID=$2 + STATE=$3 else - NODE_ID="drainer" - STATE=$2 + NODE_ID="drainer" + STATE=$2 fi OUT_DIR=/tmp/tidb_binlog_test @@ -18,7 +18,10 @@ max_commit_ts_new=0 for i in {1..15} do - binlogctl -pd-urls 127.0.0.1:2379 -cmd $NODE_KIND -show-offline-nodes > $STATUS_LOG 2>&1 + binlogctl -ssl-ca $OUT_DIR/cert/ca.pem \ + -ssl-cert $OUT_DIR/cert/client.pem \ + -ssl-key $OUT_DIR/cert/client.key \ + -pd-urls https://127.0.0.1:2379 -cmd $NODE_KIND -show-offline-nodes > $STATUS_LOG 2>&1 cat $STATUS_LOG if [ $NODE_KIND == "pumps" ]; then diff --git a/tests/_utils/run_drainer b/tests/_utils/run_drainer index 8e375fc2e..33f73ab26 100755 --- a/tests/_utils/run_drainer +++ b/tests/_utils/run_drainer @@ -22,7 +22,16 @@ echo "[$(date)] <<<<<< START IN TEST ${TEST_NAME-} FOR: $config >>>>>>" >> "$OUT if [ -f "$config" ] then - drainer -log-file $OUT_DIR/drainer.log -config $config $* >> $OUT_DIR/drainer.log 2>&1 -else - drainer -log-file $OUT_DIR/drainer.log $* >> $OUT_DIR/drainer.log 2>&1 + rm -f $OUT_DIR/drainer-config-tmp.toml + cp $config $OUT_DIR/drainer-config-tmp.toml fi + + # Append the TLS config + cat - >> "$OUT_DIR/drainer-config-tmp.toml" <> $OUT_DIR/drainer.log 2>&1 diff --git a/tests/_utils/run_pump b/tests/_utils/run_pump index 3f801fda8..ccd3e7262 100755 --- a/tests/_utils/run_pump +++ b/tests/_utils/run_pump @@ -11,7 +11,11 @@ while : do pump_num=`ps aux > temp && grep "pump -log-file ${OUT_DIR}/pump_${PORT}.log" temp | wc -l && rm temp` if [ $pump_num -ne 0 ]; then - binlogctl -pd-urls 127.0.0.1:2379 -cmd pause-pump -node-id pump:$PORT || true + echo "try pause pump" + binlogctl -ssl-ca $OUT_DIR/cert/ca.pem \ + -ssl-cert $OUT_DIR/cert/client.pem \ + -ssl-key $OUT_DIR/cert/client.key \ + -pd-urls https://127.0.0.1:2379 -cmd pause-pump -node-id pump:$PORT || true sleep 1 else break @@ -20,8 +24,16 @@ done echo "[$(date)] <<<<<< RUNNING pump >>>>>>" >> "$OUT_DIR/pump_$PORT.log" +cat - > "$OUT_DIR/pump-config.toml" <> $OUT_DIR/pump_$PORT.log 2>&1 diff --git a/tests/restart/run.sh b/tests/restart/run.sh index cb76558f8..e0dd0cce1 100755 --- a/tests/restart/run.sh +++ b/tests/restart/run.sh @@ -53,6 +53,9 @@ done echo "data is equal" # offline a pump -binlogctl -pd-urls 127.0.0.1:2379 -cmd offline-pump -node-id pump:8251 +binlogctl -ssl-ca $OUT_DIR/cert/ca.pem \ + -ssl-cert $OUT_DIR/cert/client.pem \ + -ssl-key $OUT_DIR/cert/client.key \ + -pd-urls https://127.0.0.1:2379 -cmd offline-pump -node-id pump:8251 sleep 1 check_status pumps "pump:8251" offline diff --git a/tests/run.sh b/tests/run.sh index 4dd06f59f..c86719dbd 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -13,13 +13,40 @@ pwd=$(pwd) export PATH=$PATH:$pwd/_utils export PATH=$PATH:$(dirname $pwd)/bin +generate_tls_keys() { + # Ref: https://docs.microsoft.com/en-us/azure/application-gateway/self-signed-certificates + # gRPC only supports P-256 curves, see https://github.com/grpc/grpc/issues/6722 + echo "Generate TLS keys..." + TT="$OUT_DIR/cert" + mkdir -p $TT || true + + cat - > "$TT/ipsan.cnf" < /dev/null + + for name in tidb pd tikv pump drainer client; do + openssl ecparam -out "$TT/$name.key" -name prime256v1 -genkey + openssl req -new -batch -sha256 -subj '/CN=localhost' -key "$TT/$name.key" -out "$TT/$name.csr" + openssl x509 -req -sha256 -days 1 -extensions EXT -extfile "$TT/ipsan.cnf" -in "$TT/$name.csr" -CA "$TT/ca.pem" -CAkey "$TT/ca.key" -CAcreateserial -out "$TT/$name.pem" 2> /dev/null + done +} + clean_data() { - rm -rf $OUT_DIR/pd || true - rm -rf $OUT_DIR/tidb || true - rm -rf $OUT_DIR/tikv || true - rm -rf $OUT_DIR/pump || true - rm -rf $OUT_DIR/data.drainer || true + rm -rf $OUT_DIR/* || true } stop_services() { @@ -32,10 +59,22 @@ stop_services() { } start_upstream_tidb() { + cat - > "$OUT_DIR/tidb-config.toml" < "$OUT_DIR/pd-config.toml" < "$OUT_DIR/down-tikv-config.toml" <