diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go index f6af1a7b7..f1c9615e0 100644 --- a/pkg/conn/conn.go +++ b/pkg/conn/conn.go @@ -9,7 +9,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "sync" "sync/atomic" "time" @@ -121,50 +120,13 @@ func NewMgr( storeBehavior StoreBehavior, checkRequirements bool, ) (*Mgr, error) { - addrs := strings.Split(pdAddrs, ",") - - failure := errors.Errorf("pd address (%s) has wrong format", pdAddrs) - cli := &http.Client{Timeout: 30 * time.Second} - if tlsConf != nil { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = tlsConf - cli.Transport = transport - } - - processedAddrs := make([]string, 0, len(addrs)) - for _, addr := range addrs { - if addr != "" && !strings.HasPrefix("http", addr) { - if tlsConf != nil { - addr = "https://" + addr - } else { - addr = "http://" + addr - } - } - processedAddrs = append(processedAddrs, addr) - _, failure = pdRequest(ctx, addr, clusterVersionPrefix, cli, http.MethodGet, nil) - if failure == nil { - break - } - } - if failure != nil { - return nil, errors.Annotatef(failure, "pd address (%s) not available, please check network", pdAddrs) - } - - maxCallMsgSize := []grpc.DialOption{ - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), - grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)), - } - pdClient, err := pd.NewClientWithContext( - ctx, addrs, securityOption, - pd.WithGRPCDialOptions(maxCallMsgSize...), - pd.WithCustomTimeoutOption(10*time.Second), - ) + controller, err := NewPdController(ctx, pdAddrs, tlsConf, securityOption) if err != nil { - log.Error("fail to create pd client", zap.Error(err)) + log.Error("fail to create pd controller", zap.Error(err)) return nil, err } if checkRequirements { - err = utils.CheckClusterVersion(ctx, pdClient) + err = utils.CheckClusterVersion(ctx, controller.pdClient) if err != nil { errMsg := "running BR in incompatible version of cluster, " + "if you believe it's OK, use --check-requirements=false to skip." @@ -174,7 +136,7 @@ func NewMgr( log.Info("new mgr", zap.String("pdAddrs", pdAddrs)) // Check live tikv. - stores, err := GetAllTiKVStores(ctx, pdClient, storeBehavior) + stores, err := GetAllTiKVStores(ctx, controller.pdClient, storeBehavior) if err != nil { log.Error("fail to get store", zap.Error(err)) return nil, err @@ -199,16 +161,12 @@ func NewMgr( } mgr := &Mgr{ - PdController: &PdController{ - pdClient: pdClient, - }, - storage: storage, - dom: dom, - tlsConf: tlsConf, - ownsStorage: g.OwnsStorage(), + PdController: controller, + storage: storage, + dom: dom, + tlsConf: tlsConf, + ownsStorage: g.OwnsStorage(), } - mgr.PdController.addrs = processedAddrs - mgr.PdController.cli = cli mgr.grpcClis.clis = make(map[uint64]*grpc.ClientConn) return mgr, nil } diff --git a/pkg/conn/pd.go b/pkg/conn/pd.go index 1a84dc7c3..5b2196606 100644 --- a/pkg/conn/pd.go +++ b/pkg/conn/pd.go @@ -5,6 +5,7 @@ package conn import ( "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -12,6 +13,12 @@ import ( "math" "net/http" "net/url" + "strings" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap" + "google.golang.org/grpc" "github.com/pingcap/errors" pd "github.com/tikv/pd/client" @@ -49,13 +56,6 @@ var ( } ) -// PdController manage get/update config from pd. -type PdController struct { - addrs []string - cli *http.Client - pdClient pd.Client -} - type pdHTTPRequest func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) func pdRequest( @@ -88,6 +88,68 @@ func pdRequest( return r, nil } +// PdController manage get/update config from pd. +type PdController struct { + addrs []string + cli *http.Client + pdClient pd.Client +} + +func NewPdController( + ctx context.Context, + pdAddrs string, + tlsConf *tls.Config, + securityOption pd.SecurityOption, +) (*PdController, error) { + cli := &http.Client{Timeout: 30 * time.Second} + if tlsConf != nil { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConf + cli.Transport = transport + } + + addrs := strings.Split(pdAddrs, ",") + processedAddrs := make([]string, 0, len(addrs)) + var failure error + for _, addr := range addrs { + if addr != "" && !strings.HasPrefix("http", addr) { + if tlsConf != nil { + addr = "https://" + addr + } else { + addr = "http://" + addr + } + } + processedAddrs = append(processedAddrs, addr) + _, failure = pdRequest(ctx, addr, clusterVersionPrefix, cli, http.MethodGet, nil) + if failure == nil { + break + } + } + if failure != nil { + return nil, errors.Annotatef(failure, "pd address (%s) not available, please check network", pdAddrs) + } + + maxCallMsgSize := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), + grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)), + } + pdClient, err := pd.NewClientWithContext( + ctx, addrs, securityOption, + pd.WithGRPCDialOptions(maxCallMsgSize...), + pd.WithCustomTimeoutOption(10*time.Second), + ) + if err != nil { + log.Error("fail to create pd client", zap.Error(err)) + return nil, err + } + + return &PdController{ + addrs: addrs, + cli: cli, + pdClient: pdClient, + }, nil +} + // RemoveScheduler remove pd scheduler. func (p *PdController) RemoveScheduler(ctx context.Context, scheduler string) error { return p.removeSchedulerWith(ctx, scheduler, pdRequest) diff --git a/pkg/conn/pd_test.go b/pkg/conn/pd_test.go index 1a3d81be4..d7b42fd7f 100644 --- a/pkg/conn/pd_test.go +++ b/pkg/conn/pd_test.go @@ -3,44 +3,46 @@ package conn import ( "context" "errors" + "fmt" "io" "net/http" . "github.com/pingcap/check" ) -type testPdMgrSuite struct { +type testPDControllerSuite struct { } -var _ = Suite(&testPdMgrSuite{}) +var _ = Suite(&testPDControllerSuite{}) -func (s *testPdMgrSuite) TestScheduler(c *C) { +func (s *testPDControllerSuite) TestScheduler(c *C) { ctx := context.Background() scheduler := "balance-leader-scheduler" mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { return nil, errors.New("failed") } - pdMgr := &PdController{} - err := pdMgr.removeSchedulerWith(ctx, scheduler, mock) + pdController := &PdController{addrs: []string{"", ""}} + err := pdController.removeSchedulerWith(ctx, scheduler, mock) + fmt.Printf("err: %v\n", err) c.Assert(err, ErrorMatches, "failed") - err = pdMgr.addSchedulerWith(ctx, scheduler, mock) + err = pdController.addSchedulerWith(ctx, scheduler, mock) c.Assert(err, ErrorMatches, "failed") - _, err = pdMgr.listSchedulersWith(ctx, mock) + _, err = pdController.listSchedulersWith(ctx, mock) c.Assert(err, ErrorMatches, "failed") mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } - err = pdMgr.removeSchedulerWith(ctx, scheduler, mock) + err = pdController.removeSchedulerWith(ctx, scheduler, mock) c.Assert(err, IsNil) - err = pdMgr.addSchedulerWith(ctx, scheduler, mock) + err = pdController.addSchedulerWith(ctx, scheduler, mock) c.Assert(err, IsNil) - schedulers, err := pdMgr.listSchedulersWith(ctx, mock) + schedulers, err := pdController.listSchedulersWith(ctx, mock) c.Assert(err, IsNil) c.Assert(schedulers, HasLen, 1) c.Assert(schedulers[0], Equals, scheduler)