Skip to content

Commit

Permalink
Merge pull request #1046 from go-redis/fix/dialer-context
Browse files Browse the repository at this point in the history
Pass context to Dialer
  • Loading branch information
vmihailenco authored Jun 8, 2019
2 parents 6d7c742 + 4eb0643 commit b0bd138
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 60 deletions.
6 changes: 3 additions & 3 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type ClusterOptions struct {

// Following options are copied from Options struct.

Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

OnConnect func(*Conn) error

Expand Down Expand Up @@ -1055,7 +1055,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()

cn, err := node.Client.getConn()
cn, err := node.Client.getConn(ctx)
if err != nil {
if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds)
Expand Down Expand Up @@ -1256,7 +1256,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()

cn, err := node.Client.getConn()
cn, err := node.Client.getConn(ctx)
if err != nil {
if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds)
Expand Down
4 changes: 2 additions & 2 deletions internal/pool/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func BenchmarkPoolGetPut(b *testing.B) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
if err != nil {
b.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/pool/main_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pool_test

import (
"context"
"net"
"sync"
"testing"
Expand Down Expand Up @@ -30,6 +31,6 @@ func perform(n int, cbs ...func(int)) {
wg.Wait()
}

func dummyDialer() (net.Conn, error) {
func dummyDialer(context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
}
23 changes: 12 additions & 11 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pool

import (
"context"
"errors"
"net"
"sync"
Expand Down Expand Up @@ -36,7 +37,7 @@ type Pooler interface {
NewConn() (*Conn, error)
CloseConn(*Conn) error

Get() (*Conn, error)
Get(context.Context) (*Conn, error)
Put(*Conn)
Remove(*Conn)

Expand All @@ -48,7 +49,7 @@ type Pooler interface {
}

type Options struct {
Dialer func() (net.Conn, error)
Dialer func(c context.Context) (net.Conn, error)
OnClose func(*Conn) error

PoolSize int
Expand Down Expand Up @@ -114,7 +115,7 @@ func (p *ConnPool) checkMinIdleConns() {
}

func (p *ConnPool) addIdleConn() {
cn, err := p.newConn(true)
cn, err := p.newConn(nil, true)
if err != nil {
return
}
Expand All @@ -126,11 +127,11 @@ func (p *ConnPool) addIdleConn() {
}

func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(false)
return p._NewConn(nil, false)
}

func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
cn, err := p.newConn(pooled)
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(c, pooled)
if err != nil {
return nil, err
}
Expand All @@ -148,7 +149,7 @@ func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
return cn, nil
}

func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
Expand All @@ -157,7 +158,7 @@ func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}

netConn, err := p.opt.Dialer()
netConn, err := p.opt.Dialer(c)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
Expand All @@ -177,7 +178,7 @@ func (p *ConnPool) tryDial() {
return
}

conn, err := p.opt.Dialer()
conn, err := p.opt.Dialer(nil)
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
Expand All @@ -204,7 +205,7 @@ func (p *ConnPool) getLastDialError() error {
}

// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) {
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
Expand Down Expand Up @@ -234,7 +235,7 @@ func (p *ConnPool) Get() (*Conn, error) {

atomic.AddUint32(&p.stats.Misses, 1)

newcn, err := p._NewConn(true)
newcn, err := p._NewConn(c, true)
if err != nil {
p.freeTurn()
return nil, err
Expand Down
4 changes: 3 additions & 1 deletion internal/pool/pool_single.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package pool

import "context"

type SingleConnPool struct {
cn *Conn
}
Expand All @@ -20,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *SingleConnPool) Get() (*Conn, error) {
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil
}

Expand Down
9 changes: 6 additions & 3 deletions internal/pool/pool_sticky.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package pool

import "sync"
import (
"context"
"sync"
)

type StickyConnPool struct {
pool *ConnPool
Expand Down Expand Up @@ -28,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *StickyConnPool) Get() (*Conn, error) {
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -39,7 +42,7 @@ func (p *StickyConnPool) Get() (*Conn, error) {
return p.cn, nil
}

cn, err := p.pool.Get()
cn, err := p.pool.Get(c)
if err != nil {
return nil, err
}
Expand Down
24 changes: 12 additions & 12 deletions internal/pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ var _ = Describe("ConnPool", func() {

It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())

// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
Expand All @@ -47,7 +47,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()

started <- true
_, err := connPool.Get()
_, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
done <- true

Expand Down Expand Up @@ -110,7 +110,7 @@ var _ = Describe("MinIdleConns", func() {

BeforeEach(func() {
var err error
cn, err = connPool.Get()
cn, err = connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())

Eventually(func() int {
Expand Down Expand Up @@ -145,7 +145,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) {
defer GinkgoRecover()

cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
Expand All @@ -160,7 +160,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get()
connPool.Get(nil)
close(done)
}()

Expand Down Expand Up @@ -274,7 +274,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
staleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
switch typ {
case "idle":
Expand All @@ -288,7 +288,7 @@ var _ = Describe("conns reaper", func() {

// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
Expand Down Expand Up @@ -333,7 +333,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
Expand All @@ -342,7 +342,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0))

cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
Expand Down Expand Up @@ -396,15 +396,15 @@ var _ = Describe("race", func() {

perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(cn)
}
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(cn)
Expand Down
9 changes: 5 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis

import (
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -34,7 +35,7 @@ type Options struct {

// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

// Hook that is called when new connection is established.
OnConnect func(*Conn) error
Expand Down Expand Up @@ -105,7 +106,7 @@ func (opt *Options) init() {
opt.Addr = "localhost:6379"
}
if opt.Dialer == nil {
opt.Dialer = func(network, addr string) (net.Conn, error) {
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: opt.DialTimeout,
KeepAlive: 5 * time.Minute,
Expand Down Expand Up @@ -215,8 +216,8 @@ func ParseURL(redisURL string) (*Options, error) {

func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{
Dialer: func() (net.Conn, error) {
return opt.Dialer(opt.Network, opt.Addr)
Dialer: func(c context.Context) (net.Conn, error) {
return opt.Dialer(c, opt.Network, opt.Addr)
},
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,
Expand Down
2 changes: 1 addition & 1 deletion pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ var _ = Describe("pool", func() {
})

It("removes broken connections", func() {
cn, err := client.Pool().Get()
cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)
Expand Down
Loading

0 comments on commit b0bd138

Please sign in to comment.