Skip to content

Commit

Permalink
fix: ensure websocket conns respect max duration (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
mccutchen authored Nov 30, 2023
1 parent 1c61db6 commit e0324b1
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
1 change: 1 addition & 0 deletions httpbin/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
}

ws := websocket.New(w, r, websocket.Limits{
MaxDuration: h.MaxDuration,
MaxFragmentSize: int(maxFragmentSize),
MaxMessageSize: int(maxMessageSize),
})
Expand Down
8 changes: 8 additions & 0 deletions httpbin/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)

Expand Down Expand Up @@ -80,6 +81,7 @@ var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, err

// Limits define the limits imposed on a websocket connection.
type Limits struct {
MaxDuration time.Duration
MaxFragmentSize int
MaxMessageSize int
}
Expand All @@ -88,6 +90,7 @@ type Limits struct {
type WebSocket struct {
w http.ResponseWriter
r *http.Request
maxDuration time.Duration
maxFragmentSize int
maxMessageSize int
handshook bool
Expand All @@ -98,6 +101,7 @@ func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket {
return &WebSocket{
w: w,
r: r,
maxDuration: limits.MaxDuration,
maxFragmentSize: limits.MaxFragmentSize,
maxMessageSize: limits.MaxMessageSize,
}
Expand Down Expand Up @@ -152,6 +156,10 @@ func (s *WebSocket) Serve(handler Handler) {
}
defer conn.Close()

// best effort attempt to ensure that our websocket conenctions do not
// exceed the maximum request duration
conn.SetDeadline(time.Now().Add(s.maxDuration))

// errors intentionally ignored here. it's serverLoop's responsibility to
// properly close the websocket connection with a useful error message, and
// any unexpected error returned from serverLoop is not actionable.
Expand Down
1 change: 1 addition & 0 deletions httpbin/websocket/websocket_autobahn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func TestWebSocketServer(t *testing.T) {

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: 30 * time.Second,
MaxFragmentSize: 1024 * 1024 * 16,
MaxMessageSize: 1024 * 1024 * 16,
})
Expand Down
152 changes: 152 additions & 0 deletions httpbin/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package websocket_test
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
"github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
Expand Down Expand Up @@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) {
})
}

func TestConnectionLimits(t *testing.T) {
t.Run("maximum request duration is enforced", func(t *testing.T) {
t.Parallel()

maxDuration := 500 * time.Millisecond

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: maxDuration,
// TODO: test these limits as well
MaxFragmentSize: 128,
MaxMessageSize: 256,
})
if err := ws.Handshake(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(websocket.EchoHandler)
}))
defer srv.Close()

conn, err := net.Dial("tcp", srv.Listener.Addr().String())
assert.NilError(t, err)
defer conn.Close()

reqParts := []string{
"GET /websocket/echo HTTP/1.1",
"Host: test",
"Connection: upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version: 13",
}
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
t.Logf("raw request:\n%q", reqBytes)

// first, we write the request line and headers, which should cause the
// server to respond with a 101 Switching Protocols response.
{
n, err := conn.Write(reqBytes)
assert.NilError(t, err)
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")

resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
assert.NilError(t, err)
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
}

// next, we try to read from the connection, expecting the connection
// to be closed after roughly maxDuration seconds
{
start := time.Now()
_, err := conn.Read(make([]byte, 1))
elapsed := time.Since(start)

assert.Error(t, err, io.EOF)
assert.RoughDuration(t, elapsed, maxDuration, 25*time.Millisecond)
}
})

t.Run("client closing connection", func(t *testing.T) {
t.Parallel()

// the client will close the connection well before the server closes
// the connection. make sure the server properly handles the client
// closure.
var (
clientTimeout = 100 * time.Millisecond
serverTimeout = time.Hour // should never be reached
elapsedClientTime time.Duration
elapsedServerTime time.Duration
wg sync.WaitGroup
)

wg.Add(1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer wg.Done()
start := time.Now()
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: serverTimeout,
MaxFragmentSize: 128,
MaxMessageSize: 256,
})
if err := ws.Handshake(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(websocket.EchoHandler)
elapsedServerTime = time.Since(start)
}))
defer srv.Close()

conn, err := net.Dial("tcp", srv.Listener.Addr().String())
assert.NilError(t, err)
defer conn.Close()

// should cause the client end of the connection to close well before
// the max request time configured above
conn.SetDeadline(time.Now().Add(clientTimeout))

reqParts := []string{
"GET /websocket/echo HTTP/1.1",
"Host: test",
"Connection: upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version: 13",
}
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
t.Logf("raw request:\n%q", reqBytes)

// first, we write the request line and headers, which should cause the
// server to respond with a 101 Switching Protocols response.
{
n, err := conn.Write(reqBytes)
assert.NilError(t, err)
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")

resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
assert.NilError(t, err)
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
}

// next, we try to read from the connection, expecting the connection
// to be closed after roughly clientTimeout seconds.
//
// the server should detect the closed connection and abort the
// handler, also after roughly clientTimeout seconds.
{
start := time.Now()
_, err := conn.Read(make([]byte, 1))
elapsedClientTime = time.Since(start)

// close client connection, which should interrupt the server's
// blocking read call on the connection
conn.Close()

assert.Equal(t, os.IsTimeout(err), true, "expected timeout error")
assert.RoughDuration(t, elapsedClientTime, clientTimeout, 10*time.Millisecond)

// wait for the server to finish
wg.Wait()
assert.RoughDuration(t, elapsedServerTime, clientTimeout, 10*time.Millisecond)
}
})
}

// brokenHijackResponseWriter implements just enough to satisfy the
// http.ResponseWriter and http.Hijacker interfaces and get through the
// handshake before failing to actually hijack the connection.
Expand Down

0 comments on commit e0324b1

Please sign in to comment.