diff --git a/internal/observability/helpers.go b/internal/observability/helpers.go new file mode 100644 index 00000000..932f0852 --- /dev/null +++ b/internal/observability/helpers.go @@ -0,0 +1,20 @@ +package observability + +import ( + "context" + "time" + + "go.opencensus.io/tag" +) + +func TagKeyValuesIntoContext(ctx context.Context, key tag.Key, values ...string) (context.Context, error) { + insertions := make([]tag.Mutator, len(values)) + for i, value := range values { + insertions[i] = tag.Insert(key, value) + } + return tag.New(ctx, insertions...) +} + +func SinceInMilliseconds(startTime time.Time) float64 { + return float64(time.Since(startTime).Nanoseconds()) * 1e6 +} diff --git a/internal/observability/observability.go b/internal/observability/observability.go new file mode 100644 index 00000000..ac265281 --- /dev/null +++ b/internal/observability/observability.go @@ -0,0 +1,162 @@ +package observability + +import ( + "go.opencensus.io/stats" + "go.opencensus.io/stats/view" + "go.opencensus.io/tag" +) + +// Pool metrics: +// 1. Connections taken +// 2. Connections closed +// 3. Connections usetime -- how long is a connection used until it is closed, discarded or returned +// 4. Connections reused +// 4. Connections stale +// 5. Dial errors + +const dimensionless = "1" +const milliseconds = "ms" + +var ( + MBytesRead = stats.Int64("redis/bytes_read", "The number of bytes read from the server", stats.UnitBytes) + MBytesWritten = stats.Int64("redis/bytes_written", "The number of bytes written out to the server", stats.UnitBytes) + MDials = stats.Int64("redis/dials", "The number of dials", dimensionless) + MDialErrors = stats.Int64("redis/dial_errors", "The number of dial errors", dimensionless) + MDialLatencyMilliseconds = stats.Float64("redis/dial_latency_milliseconds", "The number of milliseconds spent dialling to the Redis server", dimensionless) + MConnectionsTaken = stats.Int64("redis/connections_taken", "The number of connections taken", dimensionless) + MConnectionsClosed = stats.Int64("redis/connections_closed", "The number of connections closed", dimensionless) + MConnectionsReturned = stats.Int64("redis/connections_returned", "The number of connections returned to the pool", dimensionless) + MConnectionsReused = stats.Int64("redis/connections_reused", "The number of connections reused", dimensionless) + MConnectionsNew = stats.Int64("redis/connections_new", "The number of newly created connections", dimensionless) + MConnectionUseTime = stats.Float64("redis/connection_usetime", "The number of milliseconds for which a connection is used", milliseconds) + MPoolGets = stats.Int64("redis/pool_get_invocations", "The number of times that the connection pool is asked for a connection", dimensionless) + MPoolGetErrors = stats.Int64("redis/pool_get_invocation_errors", "The number of errors encountered when the connection pool is asked for a connection", dimensionless) + MRoundtripLatencyMilliseconds = stats.Float64("redis/roundtrip_latency", "The time in milliseconds between sending the first byte to the server until the last byte of response is received back", milliseconds) + MWriteErrors = stats.Int64("redis/write_errors", "The number of errors encountered during write routines", dimensionless) + MReadErrors = stats.Int64("redis/read_errors", "The number of errors encountered during read routines", dimensionless) + MWrites = stats.Int64("redis/writes", "The number of write invocations", dimensionless) + MReads = stats.Int64("redis/reads", "The number of read invocations", dimensionless) +) + +var KeyCommandName, _ = tag.NewKey("cmd") + +var defaultMillisecondsDistribution = view.Distribution( + // [0ms, 0.001ms, 0.005ms, 0.01ms, 0.05ms, 0.1ms, 0.5ms, 1ms, 1.5ms, 2ms, 2.5ms, 5ms, 10ms, 25ms, 50ms, 100ms, 200ms, 400ms, 600ms, 800ms, 1s, 1.5s, 2.5s, 5s, 10s, 20s, 40s, 100s, 200s, 500s] + 0, 0.000001, 0.000005, 0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.0015, 0.002, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1, 1.5, 2.5, 5, 10, 20, 40, 100, 200, 500, +) + +var defaultBytesDistribution = view.Distribution( + // [0, 1KB, 2KB, 4KB, 16KB, 64KB, 256KB, 1MB, 4MB, 16MB, 64MB, 256MB, 1GB, 4GB] + 0, 1024, 2048, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864, 268435456, 1073741824, 4294967296, +) + +var Views = []*view.View{ + { + Name: "redis/client/connection_usetime", + Description: "The duration in milliseconds for which a connection is used before being returned to the pool, closed or discarded", + + Aggregation: defaultMillisecondsDistribution, + Measure: MConnectionUseTime, + }, + { + Name: "redis/client/dial_errors", + Description: "The number of errors encountered after dialling", + Aggregation: view.Count(), + Measure: MDialErrors, + }, + { + Name: "redis/client/dials", + Description: "The number of dials", + Aggregation: view.Count(), + Measure: MDials, + }, + { + Name: "redis/client/dial_latency", + Description: "The number of milliseconds spent dialling", + Aggregation: defaultMillisecondsDistribution, + Measure: MDialLatencyMilliseconds, + }, + { + Name: "redis/client/bytes_written_cumulative", + Description: "The number of bytes written out to the server", + Aggregation: view.Count(), + Measure: MBytesWritten, + }, + { + Name: "redis/client/bytes_written_distribution", + Description: "The number of distribution of bytes written out to the server", + Aggregation: defaultBytesDistribution, + Measure: MBytesWritten, + }, + { + Name: "redis/client/bytes_read_cummulative", + Description: "The number of bytes read from a response from the server", + Aggregation: view.Count(), + Measure: MBytesRead, + }, + { + Name: "redis/client/bytes_read_distribution", + Description: "The number of distribution of bytes read from the server", + Aggregation: defaultBytesDistribution, + Measure: MBytesRead, + }, + { + Name: "redis/client/roundtrip_latency", + Description: "The distribution of milliseconds of the roundtrip latencies for method invocation", + Aggregation: defaultMillisecondsDistribution, + Measure: MRoundtripLatencyMilliseconds, + TagKeys: []tag.Key{KeyCommandName}, + }, + { + Name: "redis/client/write_errors", + Description: "The number of errors encountered during a write routine", + Aggregation: view.Count(), + Measure: MWriteErrors, + TagKeys: []tag.Key{KeyCommandName}, + }, + { + Name: "redis/client/writes", + Description: "The number of write invocations", + Aggregation: view.Count(), + Measure: MWrites, + TagKeys: []tag.Key{KeyCommandName}, + }, + { + Name: "redis/client/reads", + Description: "The number of read invocations", + Aggregation: view.Count(), + Measure: MReads, + TagKeys: []tag.Key{KeyCommandName}, + }, + { + Name: "redis/client/read_errors", + Description: "The number of errors encountered during a read routine", + Aggregation: view.Count(), + Measure: MReadErrors, + TagKeys: []tag.Key{KeyCommandName}, + }, + { + Name: "redis/client/connections_taken", + Description: "The number of connections taken out the pool", + Aggregation: view.Count(), + Measure: MConnectionsTaken, + }, + { + Name: "redis/client/connections_returned", + Description: "The number of connections returned the connection pool", + Aggregation: view.Count(), + Measure: MConnectionsReturned, + }, + { + Name: "redis/client/connections_reused", + Description: "The number of connections reused", + Aggregation: view.Count(), + Measure: MConnectionsReused, + }, + { + Name: "redis/client/connections_new", + Description: "The number of newly created connections", + Aggregation: view.Count(), + Measure: MConnectionsNew, + }, +} diff --git a/redis/conn.go b/redis/conn.go index 5aa0f32f..846b1ecc 100644 --- a/redis/conn.go +++ b/redis/conn.go @@ -17,6 +17,7 @@ package redis import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -27,6 +28,11 @@ import ( "strconv" "sync" "time" + + "go.opencensus.io/stats" + "go.opencensus.io/trace" + + "github.com/gomodule/redigo/internal/observability" ) var ( @@ -168,6 +174,21 @@ func DialUseTLS(useTLS bool) DialOption { // Dial connects to the Redis server at the given network and // address using the specified options. func Dial(network, address string, options ...DialOption) (Conn, error) { + return DialWithContext(context.Background(), network, address, options...) +} + +func DialWithContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) { + startTime := time.Now() + conn, err := doDial(network, address, options...) + measures := []stats.Measurement{observability.MDials.M(1), observability.MDialLatencyMilliseconds.M(observability.SinceInMilliseconds(startTime))} + if err != nil { + measures = append(measures, observability.MDialErrors.M(1)) + } + stats.Record(ctx, measures...) + return conn, err +} + +func doDial(network, address string, options ...DialOption) (Conn, error) { do := dialOptions{ dialer: &net.Dialer{ KeepAlive: time.Minute * 5, @@ -348,42 +369,54 @@ func (c *conn) writeLen(prefix byte, n int) error { return err } -func (c *conn) writeString(s string) error { +func (c *conn) writeString(s string) (int, error) { c.writeLen('$', len(s)) - c.bw.WriteString(s) - _, err := c.bw.WriteString("\r\n") - return err + n, _ := c.bw.WriteString(s) + nr, err := c.bw.WriteString("\r\n") + n += nr + return n, err } -func (c *conn) writeBytes(p []byte) error { +func (c *conn) writeBytes(p []byte) (int, error) { c.writeLen('$', len(p)) c.bw.Write(p) - _, err := c.bw.WriteString("\r\n") - return err + return c.bw.WriteString("\r\n") } -func (c *conn) writeInt64(n int64) error { +func (c *conn) writeInt64(n int64) (int, error) { return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) } -func (c *conn) writeFloat64(n float64) error { +func (c *conn) writeFloat64(n float64) (int, error) { return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) } -func (c *conn) writeCommand(cmd string, args []interface{}) error { +func (c *conn) writeCommand(ctx context.Context, cmd string, args []interface{}) (int64, error) { + ctx, span := trace.StartSpan(ctx, "redis.(*Conn).writeCommand") + defer span.End() + c.writeLen('*', 1+len(args)) - if err := c.writeString(cmd); err != nil { - return err + n := int64(0) + ns, err := c.writeString(cmd) + n += int64(ns) + if err != nil { + span.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: err.Error()}) + span.End() + return n, err } for _, arg := range args { - if err := c.writeArg(arg, true); err != nil { - return err + ni, err := c.writeArg(arg, true) + if err != nil { + span.End() + return n, err } + n += int64(ni) } - return nil + span.Annotatef([]trace.Attribute{trace.Int64Attribute("bytes_written", n)}, "Wrote bytes") + return n, nil } -func (c *conn) writeArg(arg interface{}, argumentTypeOK bool) (err error) { +func (c *conn) writeArg(arg interface{}, argumentTypeOK bool) (int, error) { switch arg := arg.(type) { case string: return c.writeString(arg) @@ -427,19 +460,19 @@ func (pe protocolError) Error() string { return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe)) } -func (c *conn) readLine() ([]byte, error) { +func (c *conn) readLine() ([]byte, int, error) { p, err := c.br.ReadSlice('\n') if err == bufio.ErrBufferFull { - return nil, protocolError("long response line") + return nil, len(p), protocolError("long response line") } if err != nil { - return nil, err + return nil, len(p), err } i := len(p) - 2 if i < 0 || p[i] != '\r' { - return nil, protocolError("bad response line terminator") + return nil, len(p), protocolError("bad response line terminator") } - return p[:i], nil + return p[:i], len(p), nil } // parseLen parses bulk string and array lengths. @@ -466,9 +499,9 @@ func parseLen(p []byte) (int, error) { } // parseInt parses an integer reply. -func parseInt(p []byte) (interface{}, error) { +func parseInt(p []byte) (interface{}, int, error) { if len(p) == 0 { - return 0, protocolError("malformed integer") + return 0, 0, protocolError("malformed integer") } var negate bool @@ -476,7 +509,7 @@ func parseInt(p []byte) (interface{}, error) { negate = true p = p[1:] if len(p) == 0 { - return 0, protocolError("malformed integer") + return 0, 0, protocolError("malformed integer") } } @@ -484,7 +517,7 @@ func parseInt(p []byte) (interface{}, error) { for _, b := range p { n *= 10 if b < '0' || b > '9' { - return 0, protocolError("illegal bytes in length") + return 0, len(p), protocolError("illegal bytes in length") } n += int64(b - '0') } @@ -492,79 +525,86 @@ func parseInt(p []byte) (interface{}, error) { if negate { n = -n } - return n, nil + return n, len(p), nil } var ( - okReply interface{} = "OK" - pongReply interface{} = "PONG" + okReply string = "OK" + pongReply string = "PONG" ) -func (c *conn) readReply() (interface{}, error) { - line, err := c.readLine() +func (c *conn) readReply() (interface{}, int, error) { + line, n, err := c.readLine() if err != nil { - return nil, err + return nil, n, err } if len(line) == 0 { - return nil, protocolError("short response line") + return nil, n, protocolError("short response line") } switch line[0] { case '+': switch { case len(line) == 3 && line[1] == 'O' && line[2] == 'K': // Avoid allocation for frequent "+OK" response. - return okReply, nil + return okReply, n, nil case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': // Avoid allocation in PING command benchmarks :) - return pongReply, nil + return pongReply, n, nil default: - return string(line[1:]), nil + return string(line[1:]), n, nil } case '-': - return Error(string(line[1:])), nil + return Error(string(line[1:])), n, nil case ':': return parseInt(line[1:]) case '$': n, err := parseLen(line[1:]) if n < 0 || err != nil { - return nil, err + return nil, n, err } p := make([]byte, n) - _, err = io.ReadFull(c.br, p) + ni, err := io.ReadFull(c.br, p) + ni += n if err != nil { - return nil, err + return nil, ni, err } - if line, err := c.readLine(); err != nil { - return nil, err + if line, nii, err := c.readLine(); err != nil { + return nil, ni + nii, err } else if len(line) != 0 { - return nil, protocolError("bad bulk string format") + return nil, ni, protocolError("bad bulk string format") } - return p, nil + return p, ni, nil case '*': - n, err := parseLen(line[1:]) - if n < 0 || err != nil { - return nil, err + ni, err := parseLen(line[1:]) + if ni < 0 || err != nil { + return nil, n, err } - r := make([]interface{}, n) + r := make([]interface{}, ni) + var nir int for i := range r { - r[i], err = c.readReply() + r[i], nir, err = c.readReply() + n += nir if err != nil { - return nil, err + return nil, n, err } } - return r, nil + return r, n, nil } - return nil, protocolError("unexpected response line") + return nil, n, protocolError("unexpected response line") } func (c *conn) Send(cmd string, args ...interface{}) error { + return c.SendWithContext(context.Background(), cmd, args...) +} + +func (c *conn) SendWithContext(ctx context.Context, cmd string, args ...interface{}) error { c.mu.Lock() c.pending += 1 c.mu.Unlock() if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } - if err := c.writeCommand(cmd, args); err != nil { + if _, err := c.writeCommand(ctx, cmd, args); err != nil { return c.fatal(err) } return nil @@ -591,7 +631,7 @@ func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err } c.conn.SetReadDeadline(deadline) - if reply, err = c.readReply(); err != nil { + if reply, _, err = c.readReply(); err != nil { return nil, c.fatal(err) } // When using pub/sub, the number of receives can be greater than the @@ -617,6 +657,14 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { } func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) { + return c.do(context.Background(), readTimeout, cmd, args...) +} + +func (c *conn) DoWithContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) { + return c.do(ctx, c.readTimeout, cmd, args...) +} + +func (c *conn) do(ctx context.Context, readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) { c.mu.Lock() pending := c.pending c.pending = 0 @@ -626,31 +674,68 @@ func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...inte return nil, nil } + spanName := cmd + if spanName == "" { + spanName = "do" + } + + ctx, _ = observability.TagKeyValuesIntoContext(ctx, observability.KeyCommandName, spanName) + ctx, span := trace.StartSpan(ctx, "redis.(*Conn)."+spanName) + startTime := time.Now() + defer func() { + // At the very end we need to record the overall latency + span.End() + stats.Record(ctx, observability.MRoundtripLatencyMilliseconds.M(observability.SinceInMilliseconds(startTime))) + }() + if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + span.Annotatef([]trace.Attribute{ + trace.Int64Attribute("timeout_ns", c.writeTimeout.Nanoseconds()), + }, "Set connection writeTimeout") } if cmd != "" { - if err := c.writeCommand(cmd, args); err != nil { + nw, err := c.writeCommand(ctx, cmd, args) + stats.Record(ctx, observability.MBytesWritten.M(nw), observability.MWrites.M(1)) + if err != nil { + stats.Record(ctx, observability.MWriteErrors.M(1)) + span.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: err.Error()}) return nil, c.fatal(err) } } if err := c.bw.Flush(); err != nil { + span.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: err.Error()}) return nil, c.fatal(err) } var deadline time.Time if readTimeout != 0 { deadline = time.Now().Add(readTimeout) + span.Annotatef([]trace.Attribute{ + trace.Int64Attribute("timeout_ns", readTimeout.Nanoseconds()), + }, "Set connection readTimeout") } c.conn.SetReadDeadline(deadline) + var nread int64 + defer func() { + // At the end record the number of bytes read and increment the number of reads. + stats.Record(ctx, observability.MBytesRead.M(nread), observability.MReads.M(1)) + }() + + _, readSpan := trace.StartSpan(ctx, "redis.(*Conn).readReplies") + defer readSpan.End() + if cmd == "" { reply := make([]interface{}, pending) for i := range reply { - r, e := c.readReply() + r, nir, e := c.readReply() + nread += int64(nir) if e != nil { + readSpan.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: e.Error()}) + stats.Record(ctx, observability.MReadErrors.M(1)) return nil, c.fatal(e) } reply[i] = r @@ -662,12 +747,23 @@ func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...inte var reply interface{} for i := 0; i <= pending; i++ { var e error - if reply, e = c.readReply(); e != nil { + var nir int + reply, nir, e = c.readReply() + nread += int64(nir) + if e != nil { + readSpan.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: e.Error()}) + stats.Record(ctx, observability.MReadErrors.M(1)) return nil, c.fatal(e) } if e, ok := reply.(Error); ok && err == nil { err = e } } + if err != nil { + stats.Record(ctx, observability.MReadErrors.M(1)) + readSpan.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: err.Error()}) + } else { + span.Annotatef([]trace.Attribute{trace.Int64Attribute("bytes_read", nread)}, "Read bytes") + } return reply, err } diff --git a/redis/observability.go b/redis/observability.go new file mode 100644 index 00000000..9e451e2f --- /dev/null +++ b/redis/observability.go @@ -0,0 +1,19 @@ +// Copyright 2018 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import "github.com/gomodule/redigo/internal/observability" + +var ObservabilityMetricViews = observability.Views diff --git a/redis/pool.go b/redis/pool.go index d77da325..cf4b34ec 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -16,6 +16,7 @@ package redis import ( "bytes" + "context" "crypto/rand" "crypto/sha1" "errors" @@ -26,6 +27,10 @@ import ( "time" "github.com/gomodule/redigo/internal" + "github.com/gomodule/redigo/internal/observability" + + "go.opencensus.io/stats" + "go.opencensus.io/trace" ) var ( @@ -176,11 +181,25 @@ func NewPool(newFn func() (Conn, error), maxIdle int) *Pool { // getting an underlying connection, then the connection Err, Do, Send, Flush // and Receive methods return that error. func (p *Pool) Get() Conn { - pc, err := p.get(nil) + return p.GetWithContext(context.Background()) +} + +func (p *Pool) GetWithContext(ctx context.Context) Conn { + ctx, span := trace.StartSpan(ctx, "redis.(*Pool).Get") + measures := []stats.Measurement{observability.MPoolGets.M(1)} + defer func() { + span.End() + stats.Record(ctx, measures...) + }() + + pc, err := p.get(ctx) if err != nil { + measures = append(measures, observability.MPoolGetErrors.M(1)) + span.SetStatus(trace.Status{Code: int32(trace.StatusCodeInternal), Message: err.Error()}) return errorConn{err} } - return &activeConn{p: p, pc: pc} + measures = append(measures, observability.MConnectionsTaken.M(1)) + return &activeConn{p: p, pc: pc, ctx: ctx} } // PoolStats contains pool statistics. @@ -266,11 +285,7 @@ func (p *Pool) lazyInit() { // get prunes stale connections and returns a connection from the idle list or // creates a new connection. -func (p *Pool) get(ctx interface { - Done() <-chan struct{} - Err() error -}) (*poolConn, error) { - +func (p *Pool) get(ctx context.Context) (*poolConn, error) { // Handle limit for p.Wait == true. if p.Wait && p.MaxActive > 0 { p.lazyInit() @@ -307,6 +322,7 @@ func (p *Pool) get(ctx interface { p.mu.Unlock() if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) && (p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) { + stats.Record(ctx, observability.MConnectionsReused.M(1)) return pc, nil } pc.c.Close() @@ -328,8 +344,11 @@ func (p *Pool) get(ctx interface { p.active++ p.mu.Unlock() + dialStartTime := time.Now() c, err := p.Dial() + measures := []stats.Measurement{observability.MDialLatencyMilliseconds.M(observability.SinceInMilliseconds(dialStartTime))} if err != nil { + measures = append(measures, observability.MDialErrors.M(1)) c = nil p.mu.Lock() p.active-- @@ -338,10 +357,12 @@ func (p *Pool) get(ctx interface { } p.mu.Unlock() } + measures = append(measures, observability.MConnectionsNew.M(1)) + stats.Record(ctx, measures...) return &poolConn{c: c, created: nowFunc()}, err } -func (p *Pool) put(pc *poolConn, forceClose bool) error { +func (p *Pool) put(ctx context.Context, pc *poolConn, forceClose bool) error { p.mu.Lock() if !p.closed && !forceClose { pc.t = nowFunc() @@ -365,6 +386,7 @@ func (p *Pool) put(pc *poolConn, forceClose bool) error { p.ch <- struct{}{} } p.mu.Unlock() + stats.Record(ctx, observability.MConnectionsReturned.M(1)) return nil } @@ -372,6 +394,14 @@ type activeConn struct { p *Pool pc *poolConn state int + ctx context.Context +} + +func (ac *activeConn) context() context.Context { + if ac.ctx == nil { + return context.Background() + } + return ac.ctx } var ( @@ -397,6 +427,7 @@ func (ac *activeConn) Close() error { return nil } ac.pc = nil + stats.Record(ac.context(), observability.MConnectionsClosed.M(1)) if ac.state&internal.MultiState != 0 { pc.c.Send("DISCARD") @@ -425,7 +456,7 @@ func (ac *activeConn) Close() error { } } pc.c.Do("") - ac.p.put(pc, ac.state != 0 || pc.c.Err() != nil) + ac.p.put(ac.context(), pc, ac.state != 0 || pc.c.Err() != nil) return nil } @@ -437,6 +468,10 @@ func (ac *activeConn) Err() error { return pc.c.Err() } +type contextAwareDoer interface { + DoWithContext(context.Context, string, ...interface{}) (interface{}, error) +} + func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { pc := ac.pc if pc == nil { @@ -444,6 +479,10 @@ func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interfa } ci := internal.LookupCommandInfo(commandName) ac.state = (ac.state | ci.Set) &^ ci.Clear + cwdoer, ok := pc.c.(contextAwareDoer) + if ok { + return cwdoer.DoWithContext(ac.context(), commandName, args...) + } return pc.c.Do(commandName, args...) } @@ -502,6 +541,9 @@ func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface type errorConn struct{ err error } func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err } +func (ec errorConn) DoWithContext(context.Context, string, ...interface{}) (interface{}, error) { + return nil, ec.err +} func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) { return nil, ec.err }