Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support TLS for components and downstream db #931

Merged
merged 4 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arbiter/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(cfg *Config) (srv *Server, err error) {
up := cfg.Up
down := cfg.Down

srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port)
srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
7 changes: 4 additions & 3 deletions arbiter/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package arbiter

import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"sync"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (l *dummyLoader) Close() {
type testNewServerSuite struct {
db *sql.DB
dbMock sqlmock.Sqlmock
origCreateDB func(string, string, string, int) (*sql.DB, error)
origCreateDB func(string, string, string, int, *tls.Config) (*sql.DB, error)
origNewReader func(*reader.Config) (*reader.Reader, error)
origNewLoader func(*sql.DB, ...loader.Option) (loader.Loader, error)
}
Expand All @@ -71,7 +72,7 @@ func (s *testNewServerSuite) SetUpTest(c *C) {
s.dbMock = mock

s.origCreateDB = createDB
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return s.db, nil
}

Expand Down Expand Up @@ -105,7 +106,7 @@ func (s *testNewServerSuite) TestRejectInvalidAddr(c *C) {
}

func (s *testNewServerSuite) TestStopIfFailedtoConnectDownStream(c *C) {
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return nil, fmt.Errorf("Can't create db")
}

Expand Down
30 changes: 15 additions & 15 deletions binlogctl/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,20 @@ const (

// Config holds the configuration of drainer
type Config struct {
*flag.FlagSet

Command string `toml:"cmd" json:"cmd"`
NodeID string `toml:"node-id" json:"node-id"`
DataDir string `toml:"data-dir" json:"data-dir"`
TimeZone string `toml:"time-zone" json:"time-zone"`
EtcdURLs string `toml:"pd-urls" json:"pd-urls"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
State string `toml:"state" json:"state"`
ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"`
Text string `toml:"text" json:"text"`
tls *tls.Config
*flag.FlagSet `toml:"-" json:"-"`

Command string `toml:"cmd" json:"cmd"`
NodeID string `toml:"node-id" json:"node-id"`
DataDir string `toml:"data-dir" json:"data-dir"`
TimeZone string `toml:"time-zone" json:"time-zone"`
EtcdURLs string `toml:"pd-urls" json:"pd-urls"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
State string `toml:"state" json:"state"`
ShowOfflineNodes bool `toml:"state" json:"show-offline-nodes"`
Text string `toml:"text" json:"text"`
TLS *tls.Config `toml:"-" json:"-"`
printVersion bool
}

Expand Down Expand Up @@ -134,7 +134,7 @@ func (cfg *Config) Parse(args []string) error {
SSLCert: cfg.SSLCert,
SSLKey: cfg.SSLKey,
}
cfg.tls, err = sCfg.ToTLSConfig()
cfg.TLS, err = sCfg.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config error %v", err)
}
Expand Down
39 changes: 28 additions & 11 deletions binlogctl/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package binlogctl

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
Expand All @@ -34,8 +35,8 @@ var (
)

// QueryNodesByKind returns specified nodes, like pumps/drainers
func QueryNodesByKind(urls string, kind string, showOffline bool) error {
registry, err := createRegistryFuc(urls)
func QueryNodesByKind(urls string, kind string, showOffline bool, tlsConfig *tls.Config) error {
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -56,12 +57,12 @@ func QueryNodesByKind(urls string, kind string, showOffline bool) error {
}

// UpdateNodeState update pump or drainer's state.
func UpdateNodeState(urls, kind, nodeID, state string) error {
func UpdateNodeState(urls, kind, nodeID, state string, tlsConfig *tls.Config) error {
/*
node's state can be online, pausing, paused, closing and offline.
if the state is one of them, will update the node's state saved in etcd directly.
*/
registry, err := createRegistryFuc(urls)
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -81,12 +82,12 @@ func UpdateNodeState(urls, kind, nodeID, state string) error {
}

// createRegistry returns an ectd registry
func createRegistry(urls string) (*node.EtcdRegistry, error) {
func createRegistry(urls string, tlsConfig *tls.Config) (*node.EtcdRegistry, error) {
ectdEndpoints, err := flags.ParseHostPortAddr(urls)
if err != nil {
return nil, errors.Trace(err)
}
cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, nil)
cli, err := newEtcdClientFromCfgFunc(ectdEndpoints, etcdDialTimeout, node.DefaultRootPath, tlsConfig)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -95,8 +96,8 @@ func createRegistry(urls string) (*node.EtcdRegistry, error) {
}

// ApplyAction applies action on pump or drainer
func ApplyAction(urls, kind, nodeID string, action string) error {
registry, err := createRegistryFuc(urls)
func ApplyAction(urls, kind, nodeID string, action string, tlsConfig *tls.Config) error {
registry, err := createRegistryFuc(urls, tlsConfig)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -106,18 +107,34 @@ func ApplyAction(urls, kind, nodeID string, action string) error {
return errors.Trace(err)
}

var client http.Client
url := fmt.Sprintf("http://%s/state/%s/%s", n.Addr, n.NodeID, action)
schema := "http"
if tlsConfig != nil {
schema = "https"
}

url := fmt.Sprintf("%s://%s/state/%s/%s", schema, n.Addr, n.NodeID, action)
log.Debug("send put http request", zap.String("url", url))
req, err := http.NewRequest("PUT", url, nil)
if err != nil {
return errors.Trace(err)
}
_, err = client.Do(req)
_, err = getClient(tlsConfig).Do(req)
if err == nil {
log.Info("Apply action on node success", zap.String("action", action), zap.String("NodeID", n.NodeID))
return nil
}

return errors.Trace(err)
}

func getClient(tlsConfig *tls.Config) *http.Client {
if tlsConfig == nil {
return &http.Client{}
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
}
}
16 changes: 8 additions & 8 deletions binlogctl/nodes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type testNodesSuite struct{}
func (s *testNodesSuite) SetUpTest(c *C) {
newEtcdClientFromCfgFunc = newFakeEtcdClientFromCfg
createRegistryFuc = createMockRegistry
_, err := createMockRegistry("127.0.0.1:2379")
_, err := createMockRegistry("127.0.0.1:2379", nil)
c.Assert(err, IsNil)
}

Expand All @@ -63,29 +63,29 @@ func (s *testNodesSuite) TestApplyAction(c *C) {

registerPumpForTest(c, "test", url)

err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump)
err := ApplyAction("127.0.0.1:2379", "pumps", "test2", PausePump, nil)
c.Assert(errors.IsNotFound(err), IsTrue)

// TODO: handle log information and add check
err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump)
err = ApplyAction("127.0.0.1:2379", "pumps", "test", PausePump, nil)
c.Assert(err, IsNil)
}

func (s *testNodesSuite) TestQueryNodesByKind(c *C) {
registerPumpForTest(c, "test", "127.0.0.1:8255")

// TODO: handle log information and add check
err := QueryNodesByKind("127.0.0.1:2379", "pumps", false)
err := QueryNodesByKind("127.0.0.1:2379", "pumps", false, nil)
c.Assert(err, IsNil)
}

func (s *testNodesSuite) TestUpdateNodeState(c *C) {
registerPumpForTest(c, "test", "127.0.0.1:8255")

err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused)
err := UpdateNodeState("127.0.0.1:2379", "pumps", "test2", node.Paused, nil)
c.Assert(err, ErrorMatches, ".*not found.*")

err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused)
err = UpdateNodeState("127.0.0.1:2379", "pumps", "test", node.Paused, nil)
c.Assert(err, IsNil)

// check node's state is changed to paused
Expand All @@ -104,7 +104,7 @@ func (s *testNodesSuite) TestUpdateNodeState(c *C) {

func (s *testNodesSuite) TestCreateRegistry(c *C) {
urls := "127.0.0.1:2379"
registry, err := createRegistry(urls)
registry, err := createRegistry(urls, nil)
c.Assert(err, IsNil)
c.Assert(registry, NotNil)

Expand All @@ -131,7 +131,7 @@ func (s *testNodesSuite) TestCreateRegistry(c *C) {

}

func createMockRegistry(urls string) (*node.EtcdRegistry, error) {
func createMockRegistry(urls string, _ *tls.Config) (*node.EtcdRegistry, error) {
if fakeRegistry != nil {
return fakeRegistry, nil
}
Expand Down
16 changes: 8 additions & 8 deletions cmd/binlogctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@ func main() {
case ctl.GenerateMeta:
err = ctl.GenerateMetaInfo(cfg)
case ctl.QueryPumps:
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes)
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.PumpNode, cfg.ShowOfflineNodes, cfg.TLS)
case ctl.QueryDrainers:
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes)
err = ctl.QueryNodesByKind(cfg.EtcdURLs, node.DrainerNode, cfg.ShowOfflineNodes, cfg.TLS)
case ctl.UpdatePump:
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State)
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, cfg.State, cfg.TLS)
case ctl.UpdateDrainer:
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State)
err = ctl.UpdateNodeState(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, cfg.State, cfg.TLS)
case ctl.PausePump:
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause)
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, pause, cfg.TLS)
case ctl.PauseDrainer:
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause)
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, pause, cfg.TLS)
case ctl.OfflinePump:
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close)
err = ctl.ApplyAction(cfg.EtcdURLs, node.PumpNode, cfg.NodeID, close, cfg.TLS)
case ctl.OfflineDrainer:
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close)
err = ctl.ApplyAction(cfg.EtcdURLs, node.DrainerNode, cfg.NodeID, close, cfg.TLS)
case ctl.Encrypt:
if len(cfg.Text) == 0 {
err = errors.New("need to specify the text to be encrypt")
Expand Down
10 changes: 10 additions & 0 deletions cmd/drainer/drainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ port = 3306
# when setting SyncPartialColumn drainer will allow the downstream schema
# having more or less column numbers and relax sql mode by removing STRICT_TRANS_TABLES.
# sync-mode = 1
#
# Uncomment this part if you need TLS to connecting downstream MySQL/TiDB.
# You can only specified only `ssl-ca` if there is no client certificate and don't need server to authenticate client.
# [syncer.to.security]
# Path of file that contains list of trusted SSL CAs.
# ssl-ca = "/path/to/ca.pem"
# Path of file that contains X509 certificate in PEM format.
# ssl-cert = "/path/to/drainer.pem"
# Path of file that contains X509 key in PEM format.
# ssl-key = "/path/to/drainer-key.pem"

[syncer.to.checkpoint]
# only support mysql or tidb now, you can uncomment this to control where the checkpoint is saved.
Expand Down
1 change: 1 addition & 0 deletions cmd/pump/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func main() {
log.Fatal("Failed to initialize log", zap.Error(err))
}
version.PrintVersionInfo("Pump")
log.Info("start pump...", zap.Reflect("config", cfg))

p, err := pump.NewServer(cfg)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion drainer/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package drainer

import (
"crypto/tls"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -49,6 +50,7 @@ type notifyResult struct {
// Collector collects binlog from all pump, and send binlog to syncer.
type Collector struct {
clusterID uint64
tls *tls.Config
interval time.Duration
reg *node.EtcdRegistry
tiStore kv.Storage
Expand Down Expand Up @@ -106,6 +108,7 @@ func NewCollector(cfg *Config, clusterID uint64, s *Syncer, cpt checkpoint.Check

c := &Collector{
clusterID: clusterID,
tls: cfg.tls,
interval: time.Duration(cfg.DetectInterval) * time.Second,
reg: node.NewEtcdRegistry(cli, cfg.EtcdTimeout),
pumps: make(map[string]*Pump),
Expand Down Expand Up @@ -308,7 +311,7 @@ func (c *Collector) handlePumpStatusUpdate(ctx context.Context, n *node.Status)
}

commitTS := c.merger.GetLatestTS()
p := NewPump(n.NodeID, n.Addr, c.clusterID, commitTS, c.errCh)
p := NewPump(n.NodeID, n.Addr, c.tls, c.clusterID, commitTS, c.errCh)
c.pumps[n.NodeID] = p
c.merger.AddSource(MergeSource{
ID: n.NodeID,
Expand Down
7 changes: 7 additions & 0 deletions drainer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ func (cfg *Config) Parse(args []string) error {
return errors.Errorf("tls config %+v error %v", cfg.Security, err)
}

if cfg.SyncerCfg != nil && cfg.SyncerCfg.To != nil {
cfg.SyncerCfg.To.TLS, err = cfg.SyncerCfg.To.Security.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config %+v error %v", cfg.SyncerCfg.To.Security, err)
}
}

if err = cfg.adjustConfig(); err != nil {
return errors.Trace(err)
}
Expand Down
Loading