Skip to content

Commit

Permalink
Add support for custom timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
AudriusButkevicius committed Jun 2, 2016
1 parent 7b7a891 commit 5d7eea0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
17 changes: 13 additions & 4 deletions natpmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package natpmp
import (
"fmt"
"net"
"time"
)

// Implement the NAT-PMP protocol, typically supported by Apple routers and open source
Expand All @@ -20,17 +21,25 @@ const RECOMMENDED_MAPPING_LIFETIME_SECONDS = 3600

// Interface used to make remote procedure calls.
type caller interface {
call(msg []byte) (result []byte, err error)
call(msg []byte, timeout time.Duration) (result []byte, err error)
}

// Client is a NAT-PMP protocol client.
type Client struct {
caller caller
caller caller
timeout time.Duration
}

// Create a NAT-PMP client for the NAT-PMP server at the gateway.
// Uses default timeout which is around 128 seconds.
func NewClient(gateway net.IP) (nat *Client) {
return &Client{&network{gateway}}
return &Client{&network{gateway}, 0}
}

// Create a NAT-PMP client for the NAT-PMP server at the gateway, with a timeout.
// Timeout defines the total amount of time we will keep retrying before giving up.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) (nat *Client) {
return &Client{&network{gateway}, timeout}
}

// Results of the NAT-PMP GetExternalAddress operation.
Expand Down Expand Up @@ -92,7 +101,7 @@ func (n *Client) AddPortMapping(protocol string, internalPort, requestedExternal
}

func (n *Client) rpc(msg []byte, resultSize int) (result []byte, err error) {
result, err = n.caller.call(msg)
result, err = n.caller.call(msg, n.timeout)
if err != nil {
return
}
Expand Down
9 changes: 5 additions & 4 deletions natpmp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"testing"
"time"
)

type callRecord struct {
Expand All @@ -19,7 +20,7 @@ type mockNetwork struct {
cr callRecord
}

func (n *mockNetwork) call(msg []byte) (result []byte, err error) {
func (n *mockNetwork) call(msg []byte, timeout time.Duration) (result []byte, err error) {
if bytes.Compare(msg, n.cr.msg) != 0 {
n.t.Errorf("msg=%v, expected %v", msg, n.cr.msg)
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestGetExternalAddress(t *testing.T) {
}
for i, testCase := range testCases {
t.Logf("case %d", i)
c := Client{&mockNetwork{t, testCase.cr}}
c := Client{&mockNetwork{t, testCase.cr}, 0}
result, err := c.GetExternalAddress()
if err != nil {
if err != testCase.err {
Expand Down Expand Up @@ -163,7 +164,7 @@ func TestAddPortMapping(t *testing.T) {

for i, testCase := range testCases {
t.Logf("case %d", i)
c := Client{&mockNetwork{t, testCase.cr}}
c := Client{&mockNetwork{t, testCase.cr}, 0}
result, err := c.AddPortMapping(testCase.protocol, testCase.internalPort, testCase.requestedExternalPort, testCase.lifetime)
if err != nil || testCase.err != nil {
if err != testCase.err && fmt.Sprintf("%v", err) != fmt.Sprintf("%v", testCase.err) {
Expand Down Expand Up @@ -235,7 +236,7 @@ func TestProtocolChecks(t *testing.T) {
}
for i, testCase := range testCases {
t.Logf("case %d", i)
c := Client{&mockNetwork{t, testCase.cr}}
c := Client{&mockNetwork{t, testCase.cr}, 0}
result, err := c.AddPortMapping(testCase.protocol, testCase.internalPort, testCase.requestedExternalPort, testCase.lifetime)
if err != testCase.err && fmt.Sprintf("%v", err) != fmt.Sprintf("%v", testCase.err) {
t.Errorf("err=%v != %v", err, testCase.err)
Expand Down
25 changes: 22 additions & 3 deletions network.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type network struct {
gateway net.IP
}

func (n *network) call(msg []byte) (result []byte, err error) {
func (n *network) call(msg []byte, timeout time.Duration) (result []byte, err error) {
var server net.UDPAddr
server.IP = n.gateway
server.Port = nAT_PMP_PORT
Expand All @@ -28,12 +28,18 @@ func (n *network) call(msg []byte) (result []byte, err error) {
// 16 bytes is the maximum result size.
result = make([]byte, 16)

var finalTimeout time.Time
if timeout != 0 {
finalTimeout = time.Now().Add(timeout)
}

needNewDeadline := true

var tries uint
for tries = 0; tries < nAT_TRIES; {
for tries = 0; (tries < nAT_TRIES && finalTimeout.IsZero()) || time.Now().Before(finalTimeout); {
if needNewDeadline {
err = conn.SetDeadline(time.Now().Add((nAT_INITIAL_MS << tries) * time.Millisecond))
nextDeadline := time.Now().Add((nAT_INITIAL_MS << tries) * time.Millisecond)
err = conn.SetDeadline(minTime(nextDeadline, finalTimeout))
if err != nil {
return
}
Expand Down Expand Up @@ -68,3 +74,16 @@ func (n *network) call(msg []byte) (result []byte, err error) {
err = fmt.Errorf("Timed out trying to contact gateway")
return
}

func minTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
if a.Before(b) {
return a
}
return b
}
6 changes: 4 additions & 2 deletions recorder.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package natpmp

import "time"

type callObserver interface {
observeCall(msg []byte, result []byte, err error)
}
Expand All @@ -10,8 +12,8 @@ type recorder struct {
observer callObserver
}

func (n *recorder) call(msg []byte) (result []byte, err error) {
result, err = n.child.call(msg)
func (n *recorder) call(msg []byte, timeout time.Duration) (result []byte, err error) {
result, err = n.child.call(msg, timeout)
n.observer.observeCall(msg, result, err)
return
}

0 comments on commit 5d7eea0

Please sign in to comment.