Skip to content

Commit

Permalink
Merge pull request #50 from arduino/fix_crashing_disc_handling
Browse files Browse the repository at this point in the history
Fix panic when dealing with crashing discoveries
  • Loading branch information
cmaglie authored Sep 18, 2024
2 parents 6ae82f5 + cc38790 commit fbec7bc
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
29 changes: 12 additions & 17 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag
closeAndReportError := func(err error) {
disc.statusMutex.Lock()
disc.incomingMessagesError = err
disc.statusMutex.Unlock()
disc.stopSync()
disc.killProcess()
disc.statusMutex.Unlock()
close(outChan)
if err != nil {
disc.logger.Errorf("Stopped decode loop: %v", err)
Expand All @@ -138,11 +138,7 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag

for {
var msg discoveryMessage
if err := decoder.Decode(&msg); errors.Is(err, io.EOF) {
// This is fine :flames: we exit gracefully
closeAndReportError(nil)
return
} else if err != nil {
if err := decoder.Decode(&msg); err != nil {
closeAndReportError(err)
return
}
Expand Down Expand Up @@ -184,7 +180,10 @@ func (disc *Client) waitMessage(timeout time.Duration) (*discoveryMessage, error
select {
case msg := <-disc.incomingMessagesChan:
if msg == nil {
return nil, disc.incomingMessagesError
disc.statusMutex.Lock()
err := disc.incomingMessagesError
disc.statusMutex.Unlock()
return nil, err
}
return msg, nil
case <-time.After(timeout):
Expand Down Expand Up @@ -239,9 +238,6 @@ func (disc *Client) runProcess() error {
}

func (disc *Client) killProcess() {
disc.statusMutex.Lock()
defer disc.statusMutex.Unlock()

disc.logger.Debugf("Killing discovery process")
if process := disc.process; process != nil {
disc.process = nil
Expand Down Expand Up @@ -270,7 +266,9 @@ func (disc *Client) Run() (err error) {
if err == nil {
return
}
disc.statusMutex.Lock()
disc.killProcess()
disc.statusMutex.Unlock()
}()

if err = disc.sendCommand("HELLO 1 \"arduino-cli " + disc.userAgent + "\"\n"); err != nil {
Expand All @@ -287,8 +285,6 @@ func (disc *Client) Run() (err error) {
} else if msg.ProtocolVersion > 1 {
return fmt.Errorf("protocol version not supported: requested 1, got %d", msg.ProtocolVersion)
}
disc.statusMutex.Lock()
defer disc.statusMutex.Unlock()
return nil
}

Expand All @@ -307,8 +303,6 @@ func (disc *Client) Start() error {
} else if strings.ToUpper(msg.Message) != "OK" {
return fmt.Errorf("communication out of sync, expected 'OK', received '%s'", msg.Message)
}
disc.statusMutex.Lock()
defer disc.statusMutex.Unlock()
return nil
}

Expand Down Expand Up @@ -348,8 +342,10 @@ func (disc *Client) Quit() {
if _, err := disc.waitMessage(time.Second * 5); err != nil {
disc.logger.Errorf("Quitting discovery: %s", err)
}
disc.statusMutex.Lock()
disc.stopSync()
disc.killProcess()
disc.statusMutex.Unlock()
}

// List executes an enumeration of the ports and returns a list of the available
Expand Down Expand Up @@ -377,9 +373,6 @@ func (disc *Client) List() ([]*Port, error) {
// The event channel must be consumed as quickly as possible since it may block the
// discovery if it becomes full. The channel size is configurable.
func (disc *Client) StartSync(size int) (<-chan *Event, error) {
disc.statusMutex.Lock()
defer disc.statusMutex.Unlock()

if err := disc.sendCommand("START_SYNC\n"); err != nil {
return nil, err
}
Expand All @@ -395,6 +388,8 @@ func (disc *Client) StartSync(size int) (<-chan *Event, error) {
}

// In case there is already an existing event channel in use we close it before creating a new one.
disc.statusMutex.Lock()
defer disc.statusMutex.Unlock()
disc.stopSync()
c := make(chan *Event, size)
disc.eventChan = c
Expand Down
56 changes: 56 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package discovery

import (
"fmt"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -93,3 +94,58 @@ func TestDiscoveryStdioHandling(t *testing.T) {

require.False(t, disc.Alive())
}

func TestClient(t *testing.T) {
// Build dummy-discovery
builder, err := paths.NewProcess(nil, "go", "build")
require.NoError(t, err)
builder.SetDir("dummy-discovery")
require.NoError(t, builder.Run())

t.Run("WithDiscoveryCrashingOnStartup", func(t *testing.T) {
// Run client with discovery crashing on startup
cl := NewClient("1", "dummy-discovery/dummy-discovery", "--invalid")
require.ErrorIs(t, cl.Run(), io.EOF)
})

t.Run("WithDiscoveryCrashingWhileSendingCommands", func(t *testing.T) {
// Run client with crashing discovery after 1 second
cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k")
require.NoError(t, cl.Run())

time.Sleep(time.Second)

ch, err := cl.StartSync(20)
require.Error(t, err)
require.Nil(t, ch)
})

t.Run("WithDiscoveryCrashingWhileStreamingEvents", func(t *testing.T) {
// Run client with crashing discovery after 1 second
cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k")
require.NoError(t, cl.Run())

ch, err := cl.StartSync(20)
require.NoError(t, err)

time.Sleep(time.Second)

loop:
for {
select {
case msg, ok := <-ch:
if !ok {
// Channel closed: Test passed
fmt.Println("Event channel closed")
break loop
}
fmt.Println("Recv: ", msg)
case <-time.After(time.Second):
t.Error("Crashing client did not close event channel")
break loop
}
}

cl.Quit()
})
}
9 changes: 9 additions & 0 deletions dummy-discovery/args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package args
import (
"fmt"
"os"
"time"
)

// Tag is the current git tag
Expand All @@ -38,6 +39,14 @@ func Parse() {
fmt.Printf("dummy-discovery %s (build timestamp: %s)\n", Tag, Timestamp)
os.Exit(0)
}
if arg == "-k" {
// Emulate crashing discovery
go func() {
time.Sleep(time.Millisecond * 500)
os.Exit(1)
}()
continue
}
fmt.Fprintf(os.Stderr, "invalid argument: %s\n", arg)
os.Exit(1)
}
Expand Down

0 comments on commit fbec7bc

Please sign in to comment.