diff --git a/client/client.go b/client/client.go index 2a1064e17fa..aa7a0b41ceb 100644 --- a/client/client.go +++ b/client/client.go @@ -141,23 +141,30 @@ type SecurityOption struct { // NewClient creates a PD client. func NewClient(pdAddrs []string, security SecurityOption) (Client, error) { + return NewClientWithContext(context.Background(), pdAddrs, security) +} + +// NewClientWithContext creates a PD client with context. +func NewClientWithContext(ctx context.Context, pdAddrs []string, security SecurityOption) (Client, error) { log.Info("[pd] create pd client with endpoints", zap.Strings("pd-address", pdAddrs)) - ctx, cancel := context.WithCancel(context.Background()) + ctx1, cancel := context.WithCancel(ctx) c := &client{ urls: addrsToUrls(pdAddrs), tsoRequests: make(chan *tsoRequest, maxMergeTSORequests), tsDeadlineCh: make(chan deadline, 1), checkLeaderCh: make(chan struct{}, 1), - ctx: ctx, + ctx: ctx1, cancel: cancel, security: security, } c.connMu.clientConns = make(map[string]*grpc.ClientConn) if err := c.initRetry(c.initClusterID); err != nil { + cancel() return nil, err } if err := c.initRetry(c.updateLeader); err != nil { + cancel() return nil, err } log.Info("[pd] init cluster id", zap.Uint64("cluster-id", c.clusterID)) @@ -184,7 +191,11 @@ func (c *client) initRetry(f func() error) error { if err = f(); err == nil { return nil } - time.Sleep(time.Second) + select { + case <-c.ctx.Done(): + return err + case <-time.After(time.Second): + } } return errors.WithStack(err) } diff --git a/client/client_test.go b/client/client_test.go index ebf29154873..bec87669b51 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -470,3 +470,16 @@ func (s *testClientSuite) TestScatterRegion(c *C) { }) c.Succeed() } + +var _ = Suite(&testClientCtxSuite{}) + +type testClientCtxSuite struct{} + +func (s *testClientCtxSuite) TestClientCtx(c *C) { + start := time.Now() + ctx, cancel := context.WithTimeout(context.TODO(), time.Second*3) + defer cancel() + _, err := NewClientWithContext(ctx, []string{"localhost:8080"}, SecurityOption{}) + c.Assert(err, NotNil) + c.Assert(time.Since(start), Less, time.Second*4) +}