Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add database config class #6481

Closed
wants to merge 10 commits into from
39 changes: 20 additions & 19 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,29 @@ def make_homeserver(self, reactor, clock):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))

datastores = Mock()
datastores.main = Mock(
spec=[
# Bits that Federation needs
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
"get_device_updates_by_remote",
]
)

hs = self.setup_test_homeserver(
datastore=(
Mock(
spec=[
# Bits that Federation needs
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
]
)
),
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
)

hs.datastores = datastores

return hs

def prepare(self, reactor, clock, hs):
Expand Down
3 changes: 2 additions & 1 deletion tests/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def make_homeserver(self, reactor, clock):

def prepare(self, reactor, clock, hs):

db_config = hs.config.database.get_single_database()
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
self.slaved_store = self.STORE_TYPE(
Database(hs), self.hs.get_db_conn(), self.hs
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
Database(hs, db_config), db_config.get_pool(reactor).connect(), self.hs
)
self.event_id = 0

Expand Down
11 changes: 5 additions & 6 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,13 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
d = _sth(cleanup_func, *args, **kwargs).result
server = _sth(cleanup_func, *args, **kwargs)

if isinstance(d, Failure):
d.raiseException()
database = server.config.database.get_single_database()

# Make the thread pool synchronous.
clock = d.get_clock()
pool = d.get_db_pool()
clock = server.get_clock()
pool = database.get_pool(clock._reactor)

def runWithConnection(func, *args, **kwargs):
return threads.deferToThreadPool(
Expand All @@ -336,7 +335,7 @@ def runInteraction(interaction, *args, **kwargs):
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
return d
return server


def get_clock():
Expand Down
24 changes: 15 additions & 9 deletions tests/storage/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def setUp(self):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
database = Database(hs)
self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
database = hs.config.get_single_database()
self.store = ApplicationServiceStore(database, database.make_conn(), hs)

def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
Expand Down Expand Up @@ -111,8 +111,9 @@ def setUp(self):
hs.config.event_cache_size = 1
hs.config.password_providers = []

self.db_pool = hs.get_db_pool()
self.engine = hs.database_engine
database = hs.config.get_single_database()
self.db_pool = database.get_pool(hs.get_reactor())
self.engine = database.engine

self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
Expand All @@ -125,8 +126,10 @@ def setUp(self):

self.as_yaml_files = []

database = Database(hs)
self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
db_config = hs.config.get_single_database()
self.store = TestTransactionStore(
Database(hs, db_config), db_config.make_conn(), hs
)

def _add_service(self, url, as_token, id):
as_yaml = dict(
Expand Down Expand Up @@ -419,7 +422,8 @@ def test_unique_works(self):
hs.config.event_cache_size = 1
hs.config.password_providers = []

ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
database = hs.config.get_single_database()
ApplicationServiceStore(database, database.make_conn(), hs)

@defer.inlineCallbacks
def test_duplicate_ids(self):
Expand All @@ -435,7 +439,8 @@ def test_duplicate_ids(self):
hs.config.password_providers = []

with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
database = hs.config.get_single_database()
ApplicationServiceStore(database, database.make_conn(), hs)

e = cm.exception
self.assertIn(f1, str(e))
Expand All @@ -456,7 +461,8 @@ def test_duplicate_as_tokens(self):
hs.config.password_providers = []

with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
database = hs.config.get_single_database()
ApplicationServiceStore(database, database.make_conn(), hs)

e = cm.exception
self.assertIn(f1, str(e))
Expand Down
10 changes: 6 additions & 4 deletions tests/storage/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ def runWithConnection(func, *args, **kwargs):
engine = create_engine(config.database_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
hs = TestHomeServer(
"test", db_pool=self.db_pool, config=config, database_engine=fake_engine
)
hs = TestHomeServer("test", config=config)

mock_db = Mock()
mock_db.engine = fake_engine
mock_db.get_pool.return_value = self.db_pool

self.datastore = SQLBaseStore(Database(hs), None, hs)
self.datastore = SQLBaseStore(Database(Mock(), mock_db), None, hs)

@defer.inlineCallbacks
def test_insert_1col(self):
Expand Down
1 change: 0 additions & 1 deletion tests/storage/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.db_pool = hs.get_db_pool()

self.store = hs.get_datastore()

Expand Down
38 changes: 9 additions & 29 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
Expand Down Expand Up @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore


@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
Expand Down Expand Up @@ -214,7 +214,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex

config.database_config = {
database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
Expand All @@ -226,12 +226,15 @@ def setup_test_homeserver(
},
}
else:
config.database_config = {
database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}

db_engine = create_engine(config.database_config)
database = DatabaseConnectionConfig(database_config)
db_engine = database.engine
config.database.databases = {"main_db": database}
config.database.data_stores = {"main": "main_db", "state": "main_db"}

# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
Expand All @@ -251,39 +254,24 @@ def setup_test_homeserver(
cur.close()
db_conn.close()

# we need to configure the connection pool to run the on_new_connection
# function, so that we can test code that uses custom sqlite functions
# (like rank).
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection

if datastore is None:
hs = homeserverToUse(
name,
config=config,
db_config=config.database_config,
version_string="Synapse/tests",
database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)

# Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
# date db
if not isinstance(db_engine, PostgresEngine):
db_conn = hs.get_db_conn()
yield prepare_database(db_conn, db_engine, config)
db_conn.commit()
db_conn.close()

else:
if isinstance(db_engine, PostgresEngine):
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2

# Close all the db pools
hs.get_db_pool().close()
database.get_pool(reactor).close()

dropped = False

Expand Down Expand Up @@ -326,19 +314,11 @@ def cleanup():
if homeserverToUse.__name__ == "TestHomeServer":
hs.setup_master()
else:
# If we have been given an explicit datastore we probably want to mock
# out the DataStores somehow too. This all feels a bit wrong, but then
# mocking the stores feels wrong too.
datastores = Mock(datastore=datastore)

hs = homeserverToUse(
name,
db_pool=None,
datastore=datastore,
datastores=datastores,
config=config,
version_string="Synapse/tests",
database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
Expand Down