Skip to content

Commit

Permalink
add 'connect' tests for all Redis connection classes
Browse files Browse the repository at this point in the history
  • Loading branch information
woutdenolf committed Mar 22, 2023
1 parent b167df0 commit f11ddae
Showing 1 changed file with 176 additions and 0 deletions.
176 changes: 176 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import logging
import re
import socket
import ssl
import threading

import pytest

from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection

from .ssl_certificates import get_ssl_certificate

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME)
_assert_connect(conn, tcp_address)


def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME)
_assert_connect(conn, path)


@pytest.mark.ssl
def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_certificate("server-cert.pem")
keyfile = get_ssl_certificate("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_certfile=certfile,
ssl_keyfile=keyfile,
ssl_ca_certs=certfile,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


def _assert_connect(conn, server_address, certfile=None, keyfile=None):
ready = threading.Event()
stop = threading.Event()
t = threading.Thread(
target=_redis_mock_server,
args=(server_address, ready, stop),
kwargs={"certfile": certfile, "keyfile": keyfile},
)
t.start()
try:
ready.wait()
conn.connect()
conn.disconnect()
finally:
stop.set()
t.join(timeout=5)


def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None):
try:
if isinstance(server_address, str):
family = socket.AF_UNIX
mockname = "Redis mock server (UDS)"
elif certfile:
family = socket.AF_INET
mockname = "Redis mock server (TCP-SSL)"
else:
family = socket.AF_INET
mockname = "Redis mock server (TCP)"

with socket.socket(family, socket.SOCK_STREAM) as s:
s.bind(server_address)
s.listen(1)
s.settimeout(0.1)

if certfile:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.load_cert_chain(certfile=certfile, keyfile=keyfile)

_logger.info("Start %s: %s", mockname, server_address)
ready.set()

# Wait a client connection
while not stop.is_set():
try:
sconn, _ = s.accept()
sconn.settimeout(0.1)
break
except socket.timeout:
pass
if stop.is_set():
_logger.info("Exit %s: %s", mockname, server_address)
return

# Receive commands from the client
with sconn:
if certfile:
conn = context.wrap_socket(sconn, server_side=True)
else:
conn = sconn
try:
buffer = b""
command = None
command_ptr = None
fragment_length = None
while not stop.is_set() or buffer:
try:
buffer += conn.recv(1024)
except socket.timeout:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info(
"Command fragment in %s: %s", mockname, fragment
)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if (
fragment.startswith("$")
and command[command_ptr] is None
):
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command in %s: %s", mockname, command)
resp = _COMMANDS.get(command, _ERROR_RESP)
_logger.info("Response from %s: %s", mockname, resp)
conn.sendall(resp)
command = None
finally:
if certfile:
conn.close()
_logger.info("Exit %s: %s", mockname, server_address)
except BaseException as e:
_logger.exception("Error in %s: %s", mockname, e)
raise

0 comments on commit f11ddae

Please sign in to comment.