Skip to content

Commit

Permalink
client, conn: remove idle timeout goroutine and use net.Conn read/wri…
Browse files Browse the repository at this point in the history
…te deadlines instead (perlin-network#270)

* client, conn: remove idle timeout goroutine and use net.Conn read/write deadlines instead

* node/test: fix idle timeout test so that AB or BA may either yield a timeout error, or an io.EOF
  • Loading branch information
iwasaki-kenta committed Jan 31, 2020
1 parent 1e706da commit 3878fcc
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 66 deletions.
67 changes: 5 additions & 62 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"io"
"net"
"sync"
"time"
)

type clientSide bool
Expand All @@ -33,10 +32,9 @@ const (
// The lifecycle of a client may be controlled through (*Client).WaitUntilReady, and (*Client).WaitUntilClosed. It
// provably has been useful in writing unit tests where a client instance is used under high concurrency scenarios.
//
// A client in total has four goroutines associated to it: a goroutine responsible for handling writing messages, a
// goroutine responsible for handling the recipient of messages, a goroutine for timing out the client connection
// should there be no further read/writes after some configured timeout on the clients associatd node, and a goroutine
// for handling protocol logic such as handshaking/executing Handler's.
// A client in total has three goroutines associated to it: a goroutine responsible for handling writing messages, a
// goroutine responsible for handling the recipient of messages, and a goroutine for handling protocol logic such as
// handshaking/executing Handler's.
type Client struct {
node *Node

Expand All @@ -53,11 +51,6 @@ type Client struct {
*zap.Logger
}

timeout struct {
reset chan struct{}
timer *time.Timer
}

reader *connReader
writer *connWriter

Expand Down Expand Up @@ -191,46 +184,6 @@ func (c *Client) waitUntilClosed() {
<-c.clientDone
}

func (c *Client) startTimeout(ctx context.Context) {
c.timeout.reset = make(chan struct{}, 1)

if c.node.idleTimeout == 0 {
return
}

c.timeout.timer = time.NewTimer(c.node.idleTimeout)

go func() {
defer c.timeout.timer.Stop()

for {
select {
case <-ctx.Done():
return
case <-c.clientDone:
return
case <-c.timeout.reset:
if !c.timeout.timer.Stop() {
<-c.timeout.timer.C
}

c.timeout.timer.Reset(c.node.idleTimeout)
case <-c.timeout.timer.C:
c.reportError(context.DeadlineExceeded)
c.close()
return
}
}
}()
}

func (c *Client) resetTimeout() {
select {
case c.timeout.reset <- struct{}{}:
default:
}
}

func (c *Client) outbound(ctx context.Context, addr string) {
c.addr = addr
c.side = clientSideInbound
Expand All @@ -257,7 +210,6 @@ func (c *Client) outbound(ctx context.Context, addr string) {
_ = conn.(*net.TCPConn).SetReadBuffer(10000)

c.conn = conn
c.startTimeout(ctx)

go c.readLoop(conn)
go c.writeLoop(conn)
Expand Down Expand Up @@ -288,7 +240,6 @@ func (c *Client) inbound(conn net.Conn, addr string) {
}()

ctx := context.Background()
c.startTimeout(ctx)

go c.readLoop(conn)
go c.writeLoop(conn)
Expand Down Expand Up @@ -327,8 +278,6 @@ func (c *Client) request(ctx context.Context, data []byte) (message, error) {
return message{}, err
}

c.resetTimeout()

// Await response.

var msg message
Expand All @@ -339,8 +288,6 @@ func (c *Client) request(ctx context.Context, data []byte) (message, error) {
return message{}, ctx.Err()
}

c.resetTimeout()

return msg, nil
}

Expand All @@ -357,8 +304,6 @@ func (c *Client) send(nonce uint64, data []byte) error {

c.writer.write(data)

c.resetTimeout()

return nil
}

Expand All @@ -382,8 +327,6 @@ func (c *Client) recv(ctx context.Context) (message, error) {
return message{}, err
}

c.resetTimeout()

return msg, nil
case <-ctx.Done():
return message{}, ctx.Err()
Expand Down Expand Up @@ -571,7 +514,7 @@ func (c *Client) handleLoop() {
func (c *Client) writeLoop(conn net.Conn) {
defer close(c.writerDone)

if err := c.writer.loop(conn); err != nil {
if err := c.writer.loop(conn, c.node.idleTimeout); err != nil {
if !isEOF(err) {
c.Logger().Warn("Got an error while sending messages.", zap.Error(err))
}
Expand All @@ -583,7 +526,7 @@ func (c *Client) writeLoop(conn net.Conn) {
func (c *Client) readLoop(conn net.Conn) {
defer close(c.readerDone)

if err := c.reader.loop(conn); err != nil {
if err := c.reader.loop(conn, c.node.idleTimeout); err != nil {
if !isEOF(err) {
c.Logger().Warn("Got an error while reading incoming messages.", zap.Error(err))
}
Expand Down
17 changes: 15 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"sync"
"time"
)

type connWriterState byte
Expand Down Expand Up @@ -59,7 +60,7 @@ func (c *connWriter) write(data []byte) {
c.cond.Broadcast()
}

func (c *connWriter) loop(conn net.Conn) error {
func (c *connWriter) loop(conn net.Conn, timeout time.Duration) error {
c.Lock()
c.state = connWriterRunning
c.Unlock()
Expand All @@ -76,6 +77,12 @@ func (c *connWriter) loop(conn net.Conn) error {
}()

for {
if timeout > 0 {
if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return err
}
}

c.Lock()
for c.state == connWriterRunning && len(c.pending) == 0 {
c.cond.Wait()
Expand Down Expand Up @@ -114,13 +121,19 @@ func newConnReader() *connReader {
return &connReader{pending: make(chan []byte, 1024)}
}

func (c *connReader) loop(conn net.Conn) error {
func (c *connReader) loop(conn net.Conn, timeout time.Duration) error {
defer close(c.pending)

header := make([]byte, 4)
reader := bufio.NewReader(conn)

for {
if timeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return err
}
}

if _, err := io.ReadFull(reader, header); err != nil {
return err
}
Expand Down
9 changes: 7 additions & 2 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"io"
"net"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -272,7 +273,9 @@ func TestIdleTimeoutServerSide(t *testing.T) {
ba.WaitUntilClosed()
ab.WaitUntilClosed()

assert.EqualValues(t, ab.Error(), context.DeadlineExceeded)
var abError *net.OpError
assert.True(t, errors.As(ab.Error(), &abError) && abError.Timeout() || ab.Error() == io.EOF)

assert.EqualValues(t, ba.Error(), io.EOF)

assert.Len(t, a.Inbound(), 0)
Expand Down Expand Up @@ -306,7 +309,9 @@ func TestIdleTimeoutClientSide(t *testing.T) {
ba.WaitUntilClosed()
ab.WaitUntilClosed()

assert.EqualValues(t, ba.Error(), context.DeadlineExceeded)
var baError *net.OpError
assert.True(t, errors.As(ba.Error(), &baError) && baError.Timeout() || ba.Error() == io.EOF)

assert.EqualValues(t, ab.Error(), io.EOF)

assert.Len(t, a.Inbound(), 0)
Expand Down

0 comments on commit 3878fcc

Please sign in to comment.