Skip to content

Commit

Permalink
Fixed #33277 -- Disallowed database connections in threads in SimpleT…
Browse files Browse the repository at this point in the history
…estCase.
  • Loading branch information
David-Wobrock authored and felixxm committed Jan 3, 2024
1 parent 45f778e commit 8fb0be3
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
31 changes: 31 additions & 0 deletions django/test/testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import copy, deepcopy
from difflib import get_close_matches
from functools import wraps
from unittest import mock
from unittest.suite import _DebugResult
from unittest.util import safe_repr
from urllib.parse import (
Expand Down Expand Up @@ -37,6 +38,7 @@
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
from django.core.signals import setting_changed
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.forms.fields import CharField
from django.http import QueryDict
from django.http.request import split_domain_port, validate_host
Expand Down Expand Up @@ -255,6 +257,13 @@ def _add_databases_failures(cls):
}
method = getattr(connection, name)
setattr(connection, name, _DatabaseFailure(method, message))
cls.enterClassContext(
mock.patch.object(
BaseDatabaseWrapper,
"ensure_connection",
new=cls.ensure_connection_patch_method(),
)
)

@classmethod
def _remove_databases_failures(cls):
Expand All @@ -266,6 +275,28 @@ def _remove_databases_failures(cls):
method = getattr(connection, name)
setattr(connection, name, method.wrapped)

@classmethod
def ensure_connection_patch_method(cls):
real_ensure_connection = BaseDatabaseWrapper.ensure_connection

def patched_ensure_connection(self, *args, **kwargs):
if (
self.connection is None
and self.alias not in cls.databases
and self.alias != NO_DB_ALIAS
):
# Connection has not yet been established, but the alias is not allowed.
message = cls._disallowed_database_msg % {
"test": f"{cls.__module__}.{cls.__qualname__}",
"alias": self.alias,
"operation": "threaded connections",
}
return _DatabaseFailure(self.ensure_connection, message)()

real_ensure_connection(self, *args, **kwargs)

return patched_ensure_connection

def __call__(self, result=None):
"""
Wrapper around default __call__ method to perform common Django test
Expand Down
3 changes: 3 additions & 0 deletions docs/releases/5.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ Tests
* The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that
an HTML fragment is not contained in the given HTML haystack.

* In order to enforce test isolation, database connections inside threads are
no longer allowed in :class:`~django.test.SimpleTestCase`.

URLs
~~~~

Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import threading
import unittest
import warnings
from io import StringIO
Expand Down Expand Up @@ -2093,6 +2094,29 @@ def test_disallowed_database_chunked_cursor_queries(self):
with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
next(Car.objects.iterator())

def test_disallowed_thread_database_connection(self):
expected_message = (
"Database threaded connections to 'default' are not allowed in "
"SimpleTestCase subclasses. Either subclass TestCase or TransactionTestCase"
" to ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)

exceptions = []

def thread_func():
try:
Car.objects.first()
except DatabaseOperationForbidden as e:
exceptions.append(e)

t = threading.Thread(target=thread_func)
t.start()
t.join()
self.assertEqual(len(exceptions), 1)
self.assertEqual(exceptions[0].args[0], expected_message)


class AllowedDatabaseQueriesTests(SimpleTestCase):
databases = {"default"}
Expand All @@ -2103,6 +2127,14 @@ def test_allowed_database_queries(self):
def test_allowed_database_chunked_cursor_queries(self):
next(Car.objects.iterator(), None)

def test_allowed_threaded_database_queries(self):
def thread_func():
next(Car.objects.iterator(), None)

t = threading.Thread(target=thread_func)
t.start()
t.join()


class DatabaseAliasTests(SimpleTestCase):
def setUp(self):
Expand Down

0 comments on commit 8fb0be3

Please sign in to comment.