Skip to content

Commit

Permalink
Prevent unlimited reads.
Browse files Browse the repository at this point in the history
This can mitigate some denial of service scenarios.
  • Loading branch information
aaugustin committed Sep 5, 2021
1 parent add0d46 commit 0a935b8
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 28 deletions.
40 changes: 29 additions & 11 deletions src/websockets/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@
from . import datastructures, exceptions


# Maximum total size of headers is around 256 * 4 KiB = 1 MiB
MAX_HEADERS = 256
MAX_LINE = 4110

# We can use the same limit for the request line and header lines:
# "GET <4096 bytes> HTTP/1.1\r\n" = 4111 bytes
# "Set-Cookie: <4097 bytes>\r\n" = 4111 bytes
# (RFC requires 4096 bytes; for some reason Firefox supports 4097 bytes.)
MAX_LINE = 4111

# Support for HTTP response bodies is intended to read an error message
# returned by a server. It isn't designed to perform large file transfers.
MAX_BODY = 2 ** 20 # 1 MiB


def d(value: bytes) -> str:
Expand Down Expand Up @@ -60,7 +70,7 @@ class Request:
@classmethod
def parse(
cls,
read_line: Callable[[], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, Request]:
"""
Parse a WebSocket handshake request.
Expand Down Expand Up @@ -157,9 +167,9 @@ class Response:
@classmethod
def parse(
cls,
read_line: Callable[[], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes]],
read_exact: Callable[[int], Generator[None, None, bytes]],
read_to_eof: Callable[[], Generator[None, None, bytes]],
read_to_eof: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, Response]:
"""
Parse a WebSocket handshake response.
Expand Down Expand Up @@ -234,7 +244,16 @@ def parse(
content_length = int(raw_content_length)

if content_length is None:
body = yield from read_to_eof()
try:
body = yield from read_to_eof(MAX_BODY)
except RuntimeError:
raise exceptions.SecurityError(
f"body too large: over {MAX_BODY} bytes"
)
elif content_length > MAX_BODY:
raise exceptions.SecurityError(
f"body too large: {content_length} bytes"
)
else:
body = yield from read_exact(content_length)

Expand All @@ -255,7 +274,7 @@ def serialize(self) -> bytes:


def parse_headers(
read_line: Callable[[], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, datastructures.Headers]:
"""
Parse HTTP headers.
Expand Down Expand Up @@ -306,7 +325,7 @@ def parse_headers(


def parse_line(
read_line: Callable[[], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, bytes]:
"""
Parse a single line.
Expand All @@ -322,10 +341,9 @@ def parse_line(
SecurityError: if the response exceeds a security limit.
"""
# Security: TODO: add a limit here
line = yield from read_line()
# Security: this guarantees header values are small (hard-coded = 4 KiB)
if len(line) > MAX_LINE:
try:
line = yield from read_line(MAX_LINE)
except RuntimeError:
raise exceptions.SecurityError("line too long")
# Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5
if not line.endswith(b"\r\n"):
Expand Down
21 changes: 19 additions & 2 deletions src/websockets/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ def __init__(self) -> None:
self.buffer = bytearray()
self.eof = False

def read_line(self) -> Generator[None, None, bytes]:
def read_line(self, m: int) -> Generator[None, None, bytes]:
"""
Read a LF-terminated line from the stream.
This is a generator-based coroutine.
The return value includes the LF character.
Args:
m: maximum number of bytes to read; this is a security limit.
Raises:
EOFError: if the stream ends without a LF.
RuntimeError: if the line is longer than ``m`` bytes.
"""
n = 0 # number of bytes to read
Expand All @@ -36,9 +40,13 @@ def read_line(self) -> Generator[None, None, bytes]:
if n > 0:
break
p = len(self.buffer)
if p > m:
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
if self.eof:
raise EOFError(f"stream ends after {p} bytes, before end of line")
yield
if n > m:
raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes")
r = self.buffer[:n]
del self.buffer[:n]
return r
Expand Down Expand Up @@ -66,14 +74,23 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]:
del self.buffer[:n]
return r

def read_to_eof(self) -> Generator[None, None, bytes]:
def read_to_eof(self, m: int) -> Generator[None, None, bytes]:
"""
Read all bytes from the stream.
This is a generator-based coroutine.
Args:
m: maximum number of bytes to read; this is a security limit.
Raises:
RuntimeError: if more than ``m`` bytes are buffered before the stream ends.
"""
while not self.eof:
p = len(self.buffer)
if p > m:
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
yield
r = self.buffer[:]
del self.buffer[:]
Expand Down
25 changes: 22 additions & 3 deletions tests/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def test_parse_empty(self):
with self.assertRaises(EOFError) as raised:
next(self.parse())
self.assertEqual(
str(raised.exception), "connection closed while reading HTTP status line"
str(raised.exception),
"connection closed while reading HTTP status line",
)

def test_parse_invalid_status_line(self):
Expand Down Expand Up @@ -230,6 +231,24 @@ def test_parse_body_without_content_length(self):
response = self.assertGeneratorReturns(gen)
self.assertEqual(response.body, b"Hello world!\n")

def test_parse_body_with_content_length_too_long(self):
self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n")
with self.assertRaises(SecurityError) as raised:
next(self.parse())
self.assertEqual(
str(raised.exception),
"body too large: 1048577 bytes",
)

def test_parse_body_without_content_length_too_long(self):
self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577)
with self.assertRaises(SecurityError) as raised:
next(self.parse())
self.assertEqual(
str(raised.exception),
"body too large: over 1048576 bytes",
)

def test_parse_body_with_transfer_encoding(self):
self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n")
with self.assertRaises(NotImplementedError) as raised:
Expand Down Expand Up @@ -314,8 +333,8 @@ def test_parse_too_long_value(self):
next(self.parse_headers())

def test_parse_too_long_line(self):
# Header line contains 5 + 4104 + 2 = 4111 bytes.
self.reader.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n")
# Header line contains 5 + 4105 + 2 = 4112 bytes.
self.reader.feed_data(b"foo: " + b"a" * 4105 + b"\r\n\r\n")
with self.assertRaises(SecurityError):
next(self.parse_headers())

Expand Down
65 changes: 53 additions & 12 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@ def setUp(self):
def test_read_line(self):
self.reader.feed_data(b"spam\neggs\n")

gen = self.reader.read_line()
gen = self.reader.read_line(32)
line = self.assertGeneratorReturns(gen)
self.assertEqual(line, b"spam\n")

gen = self.reader.read_line()
gen = self.reader.read_line(32)
line = self.assertGeneratorReturns(gen)
self.assertEqual(line, b"eggs\n")

def test_read_line_need_more_data(self):
self.reader.feed_data(b"spa")

gen = self.reader.read_line()
gen = self.reader.read_line(32)
self.assertGeneratorRunning(gen)
self.reader.feed_data(b"m\neg")
line = self.assertGeneratorReturns(gen)
self.assertEqual(line, b"spam\n")

gen = self.reader.read_line()
gen = self.reader.read_line(32)
self.assertGeneratorRunning(gen)
self.reader.feed_data(b"gs\n")
line = self.assertGeneratorReturns(gen)
Expand All @@ -37,11 +37,34 @@ def test_read_line_not_enough_data(self):
self.reader.feed_data(b"spa")
self.reader.feed_eof()

gen = self.reader.read_line()
gen = self.reader.read_line(32)
with self.assertRaises(EOFError) as raised:
next(gen)
self.assertEqual(
str(raised.exception), "stream ends after 3 bytes, before end of line"
str(raised.exception),
"stream ends after 3 bytes, before end of line",
)

def test_read_line_too_long(self):
self.reader.feed_data(b"spam\neggs\n")

gen = self.reader.read_line(2)
with self.assertRaises(RuntimeError) as raised:
next(gen)
self.assertEqual(
str(raised.exception),
"read 5 bytes, expected no more than 2 bytes",
)

def test_read_line_too_long_need_more_data(self):
self.reader.feed_data(b"spa")

gen = self.reader.read_line(2)
with self.assertRaises(RuntimeError) as raised:
next(gen)
self.assertEqual(
str(raised.exception),
"read 3 bytes, expected no more than 2 bytes",
)

def test_read_exact(self):
Expand Down Expand Up @@ -78,11 +101,12 @@ def test_read_exact_not_enough_data(self):
with self.assertRaises(EOFError) as raised:
next(gen)
self.assertEqual(
str(raised.exception), "stream ends after 3 bytes, expected 4 bytes"
str(raised.exception),
"stream ends after 3 bytes, expected 4 bytes",
)

def test_read_to_eof(self):
gen = self.reader.read_to_eof()
gen = self.reader.read_to_eof(32)

self.reader.feed_data(b"spam")
self.assertGeneratorRunning(gen)
Expand All @@ -94,10 +118,21 @@ def test_read_to_eof(self):
def test_read_to_eof_at_eof(self):
self.reader.feed_eof()

gen = self.reader.read_to_eof()
gen = self.reader.read_to_eof(32)
data = self.assertGeneratorReturns(gen)
self.assertEqual(data, b"")

def test_read_to_eof_too_long(self):
gen = self.reader.read_to_eof(2)

self.reader.feed_data(b"spam")
with self.assertRaises(RuntimeError) as raised:
next(gen)
self.assertEqual(
str(raised.exception),
"read 4 bytes, expected no more than 2 bytes",
)

def test_at_eof_after_feed_data(self):
gen = self.reader.at_eof()
self.assertGeneratorRunning(gen)
Expand Down Expand Up @@ -137,16 +172,22 @@ def test_feed_data_after_feed_eof(self):
self.reader.feed_eof()
with self.assertRaises(EOFError) as raised:
self.reader.feed_data(b"spam")
self.assertEqual(str(raised.exception), "stream ended")
self.assertEqual(
str(raised.exception),
"stream ended",
)

def test_feed_eof_after_feed_eof(self):
self.reader.feed_eof()
with self.assertRaises(EOFError) as raised:
self.reader.feed_eof()
self.assertEqual(str(raised.exception), "stream ended")
self.assertEqual(
str(raised.exception),
"stream ended",
)

def test_discard(self):
gen = self.reader.read_to_eof()
gen = self.reader.read_to_eof(32)

self.reader.feed_data(b"spam")
self.reader.discard()
Expand Down

0 comments on commit 0a935b8

Please sign in to comment.