diff --git a/cylc/flow/dbstatecheck.py b/cylc/flow/dbstatecheck.py index 1fae3e0feb..b38c394ccd 100644 --- a/cylc/flow/dbstatecheck.py +++ b/cylc/flow/dbstatecheck.py @@ -72,7 +72,7 @@ def __init__(self, rund, workflow, db_path=None): if not os.path.exists(db_path): raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), db_path) - self.conn = sqlite3.connect(db_path, timeout=10.0) + self.conn: sqlite3.Connection = sqlite3.connect(db_path, timeout=10.0) # Get workflow point format. try: @@ -84,8 +84,17 @@ def __init__(self, rund, workflow, db_path=None): self.db_point_fmt = self._get_db_point_format_compat() self.c7_back_compat_mode = True except sqlite3.OperationalError: + with suppress(Exception): + self.conn.close() raise exc # original error + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Close DB connection when leaving context manager.""" + self.conn.close() + def adjust_point_to_db(self, cycle, offset): """Adjust a cycle point (with offset) to the DB point format. diff --git a/tests/integration/test_dbstatecheck.py b/tests/integration/test_dbstatecheck.py index 33db2ec4c2..08fd59aa0e 100644 --- a/tests/integration/test_dbstatecheck.py +++ b/tests/integration/test_dbstatecheck.py @@ -52,13 +52,14 @@ async def checker( schd: Scheduler = mod_scheduler(wid, paused_start=False) async with mod_run(schd): await mod_complete(schd) - schd.pool.force_trigger_tasks(['1000/good'], [2]) + schd.pool.force_trigger_tasks(['1000/good'], ['2']) # Allow a cycle of the main loop to pass so that flow 2 can be # added to db await sleep(1) - yield CylcWorkflowDBChecker( + with CylcWorkflowDBChecker( 'somestring', 'utterbunkum', schd.workflow_db_mgr.pub_path - ) + ) as _checker: + yield _checker def test_basic(checker): diff --git a/tests/unit/test_db_compat.py b/tests/unit/test_db_compat.py index 5393e85e67..0a04bea0b1 100644 --- a/tests/unit/test_db_compat.py +++ b/tests/unit/test_db_compat.py @@ -129,14 +129,13 @@ def test_cylc_7_db_wflow_params_table(_setup_db): rf'("cycle_point_format", "{ptformat}")' ) db_file_name = _setup_db([create, insert]) - checker = CylcWorkflowDBChecker('foo', 'bar', db_path=db_file_name) + with CylcWorkflowDBChecker('foo', 'bar', db_path=db_file_name) as checker: + with pytest.raises( + sqlite3.OperationalError, match="no such table: workflow_params" + ): + checker._get_db_point_format() - with pytest.raises( - sqlite3.OperationalError, match="no such table: workflow_params" - ): - checker._get_db_point_format() - - assert checker.db_point_fmt == ptformat + assert checker.db_point_fmt == ptformat def test_pre_830_task_action_timers(_setup_db):