Skip to content

Commit

Permalink
Merge pull request perlin-network#213 from perlin-network/fix-peer-di…
Browse files Browse the repository at this point in the history
…sconnect

Fix a bug that happens when peer's connection is closed by the dialler.
  • Loading branch information
iwasaki-kenta committed Mar 5, 2019
2 parents e7b0ac2 + d058e32 commit f9c1f15
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 10 deletions.
4 changes: 3 additions & 1 deletion node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ func TestCallbacks(t *testing.T) {

clearCounters()

<-peer.DisconnectAsync()
peer.Disconnect()

time.Sleep(10 * time.Millisecond)

// check that the expected callbacks were called on the dialer
compareCB(callbacks[src], map[string]int{
Expand Down
16 changes: 8 additions & 8 deletions peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ func (p *Peer) spawnReceiveWorker() {
p.onConnErrorCallbacks.RunCallbacks(p.node, errors.Wrap(err, "failed to read message size"))
}

p.Disconnect()
p.DisconnectAsync()
continue
}

if size > p.node.maxMessageSize {
p.onConnErrorCallbacks.RunCallbacks(p.node, errors.Errorf("exceeded max message size; got size %d", size))

p.Disconnect()
p.DisconnectAsync()
continue
}

Expand All @@ -202,22 +202,22 @@ func (p *Peer) spawnReceiveWorker() {
if err != nil {
p.onConnErrorCallbacks.RunCallbacks(p.node, errors.Wrap(err, "failed to read remaining message contents"))

p.Disconnect()
p.DisconnectAsync()
continue
}

if seen < int(size) {
p.onConnErrorCallbacks.RunCallbacks(p.node, errors.Errorf("only read %d bytes when expected to read %d from peer", seen, size))

p.Disconnect()
p.DisconnectAsync()
continue
}

b, errs := p.beforeMessageReceivedCallbacks.RunCallbacks(buf, p.node)
if len(errs) > 0 {
log.Warn().Errs("errors", errs).Msg("Got errors running BeforeMessageReceived callbacks.")

p.Disconnect()
p.DisconnectAsync()
continue
}
buf = b.([]byte)
Expand All @@ -227,7 +227,7 @@ func (p *Peer) spawnReceiveWorker() {
if opcode == OpcodeNil || err != nil {
p.onConnErrorCallbacks.RunCallbacks(p.node, errors.Wrap(err, "failed to decode message"))

p.Disconnect()
p.DisconnectAsync()
continue
}

Expand All @@ -239,14 +239,14 @@ func (p *Peer) spawnReceiveWorker() {
recv.lock <- struct{}{}
<-recv.lock
case <-time.After(p.node.receiveMessageTimeout):
p.Disconnect()
p.DisconnectAsync()
continue
}

if errs := p.afterMessageReceivedCallbacks.RunCallbacks(p.node); len(errs) > 0 {
log.Warn().Errs("errors", errs).Msg("Got errors running AfterMessageReceived callbacks.")

p.Disconnect()
p.DisconnectAsync()
continue
}
}
Expand Down
64 changes: 64 additions & 0 deletions peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,70 @@ func TestPeer(t *testing.T) {
p.Disconnect()
}

// Test dialler the peer's connection
func TestPeerConnDisconnected(t *testing.T) {
log.Disable()
defer log.Enable()

resetOpcodes()
opcode := RegisterMessage(NextAvailableOpcode(), (*testMsg)(nil))

var port uint16 = 8888
var err error

var wgListen sync.WaitGroup
wgListen.Add(1)

layer := transport.NewBuffered()

go func() {
params := DefaultParams()
params.Keys = ed25519.RandomKeys()
params.Host = "127.0.0.1"
params.Port = port
params.Transport = layer

node, err := NewNode(params)
assert.Nil(t, err)
wgListen.Done()

node.OnPeerConnected(func(node *Node, peer *Peer) error {
<- peer.Receive(opcode)

peer.Disconnect()
return nil
})

node.Listen()
}()

wgListen.Wait()
conn, err := layer.Dial(fmt.Sprintf("%s:%d", "127.0.0.1", port))
assert.NoError(t, err)

p := peer(t, layer, conn, port)

p.OnConnError(func(node *Node, peer *Peer, err error) error {
assert.FailNow(t, "OnConnError should never be called")
return nil
})

var wgDisconnect sync.WaitGroup
wgDisconnect.Add(1)

p.OnDisconnect(func(node *Node, peer *Peer) error {
wgDisconnect.Done()
return nil
})

p.init()

err = p.SendMessage(testMsg{Text: "hello"})
assert.NoError(t, err)

wgDisconnect.Wait()
}

// check the state equal to the expected state, and then increment it by 1
func check(t *testing.T, currentState *int32, expectedState int32) {
assert.Equal(t, expectedState, atomic.LoadInt32(currentState))
Expand Down
2 changes: 1 addition & 1 deletion protocol/mod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ func TestProtocol(t *testing.T) {
assert.Equal(t, atomic.LoadUint32(aliceCount), uint32(10))
assert.Equal(t, atomic.LoadUint32(bobCount), uint32(6))

assert.Equal(t, atomic.LoadUint32(&aliceDisconnected), uint32(0))
assert.Equal(t, atomic.LoadUint32(&aliceDisconnected), uint32(1))
}

0 comments on commit f9c1f15

Please sign in to comment.