Skip to content

Commit

Permalink
parse Redis commands in the mock server and shutdown server on failure
Browse files Browse the repository at this point in the history
  • Loading branch information
woutdenolf committed Mar 22, 2023
1 parent 87cf9b1 commit 50123cc
Showing 1 changed file with 123 additions and 65 deletions.
188 changes: 123 additions & 65 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
import socket
import ssl
import threading
Expand All @@ -12,6 +13,43 @@


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


@pytest.fixture
def ssl_cert(tcp_address, tmpdir):
"""More or less equivalent to
.. code::
openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
"""
host, _ = tcp_address
ca = trustme.CA()
cert = ca.issue_cert(host, common_name="trustme")

server_pem = str(tmpdir / "server.pem")
cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem)

client_pem = str(tmpdir / "client.pem")
ca.cert_pem.write_to_path(path=client_pem)

return client_pem, server_pem


def test_tcp_connect(tcp_address):
Expand All @@ -35,7 +73,25 @@ def test_tcp_ssl_connect(tcp_address, ssl_cert):
_assert_connect(conn, tcp_address, certfile=server_pem)


def redis_mock_server(server_address, ready, commands, certfile=None):
def _assert_connect(conn, server_address, certfile=None):
ready = threading.Event()
stop = threading.Event()
t = threading.Thread(
target=_redis_mock_server,
args=(server_address, ready, stop),
kwargs={"certfile": certfile},
)
t.start()
try:
ready.wait()
conn.connect()
conn.disconnect()
finally:
stop.set()
t.join(timeout=5)


def _redis_mock_server(server_address, ready, stop, certfile=None):
try:
if isinstance(server_address, str):
family = socket.AF_UNIX
Expand All @@ -46,86 +102,88 @@ def redis_mock_server(server_address, ready, commands, certfile=None):
else:
family = socket.AF_INET
mockname = "Redis mock server (TCP)"

with socket.socket(family, socket.SOCK_STREAM) as s:
s.bind(server_address)
s.listen(1)
s.settimeout(0.1)

if certfile:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.load_cert_chain(certfile=certfile)

_logger.info("Start %s: %s", mockname, server_address)
ready.set()
ssock, _ = s.accept()
with ssock:

# Wait a client connection
while not stop.is_set():
try:
sconn, _ = s.accept()
sconn.settimeout(0.1)
break
except socket.timeout:
pass
if stop.is_set():
_logger.info("Exit %s: %s", mockname, server_address)
return

# Receive commands from the client
with sconn:
if certfile:
conn = context.wrap_socket(ssock, server_side=True)
conn = context.wrap_socket(sconn, server_side=True)
else:
conn = ssock
conn = sconn
try:
while True:
data = conn.recv(1024)
if not data:
_logger.info("Exit %s: %s", mockname, server_address)
break
_logger.info("Command in %s: %s", mockname, data)
resp = b"+ERROR\r\n"
resp = commands.get(data, resp)
_logger.info("Response from %s: %s", mockname, resp)
conn.sendall(resp)
buffer = b""
command = None
command_ptr = None
fragment_length = None
while not stop.is_set() or buffer:
try:
buffer += conn.recv(1024)
except socket.timeout:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info(
"Command fragment in %s: %s", mockname, fragment
)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if (
fragment.startswith("$")
and command[command_ptr] is None
):
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command in %s: %s", mockname, command)
resp = _COMMANDS.get(command, _ERROR_RESP)
_logger.info("Response from %s: %s", mockname, resp)
conn.sendall(resp)
command = None
finally:
if certfile:
conn.close()
_logger.info("Exit %s: %s", mockname, server_address)
except BaseException as e:
_logger.exception("Error in %s: %s", mockname, e)
raise


def _assert_connect(conn, server_address, **server_kwargs):
command = conn.pack_command("CLIENT", "SETNAME", _CLIENT_NAME)[0]
commands = {command: b"+OK\r\n"}

ready = threading.Event()
t = threading.Thread(
target=redis_mock_server,
args=(server_address, ready, commands),
kwargs=server_kwargs,
)
t.start()
ready.wait()
conn.connect()
conn.disconnect()
t.join()


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


@pytest.fixture
def ssl_cert(tcp_address, tmpdir):
"""More or less equivalent to
.. code::
openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
"""
host, _ = tcp_address
ca = trustme.CA()
cert = ca.issue_cert(host, common_name="trustme")

server_pem = str(tmpdir / "server.pem")
cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem)

client_pem = str(tmpdir / "client.pem")
ca.cert_pem.write_to_path(path=client_pem)

return client_pem, server_pem

0 comments on commit 50123cc

Please sign in to comment.