diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b67975c..a9513ac 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,6 +34,8 @@ jobs: uses: actions/checkout@v4 - name: Install requirements run: python -m pip install -r requirements.txt + - name: Install backport of unittest.mock + run: python -m pip install mock - name: Run tests run: python -m django test tracdjangoplugin.tests env: diff --git a/DjangoPlugin/tracdjangoplugin/middlewares.py b/DjangoPlugin/tracdjangoplugin/middlewares.py new file mode 100644 index 0000000..0f18681 --- /dev/null +++ b/DjangoPlugin/tracdjangoplugin/middlewares.py @@ -0,0 +1,22 @@ +from django.core.signals import request_finished, request_started + + +class DjangoDBManagementMiddleware: + """ + A simple WSGI middleware that manually manages opening/closing db connections. + + Django normally does that as part of its own middleware chain, but we're using Trac's middleware + so we must do this by hand. + This hopefully prevents open connections from piling up. + """ + + def __init__(self, application): + self.application = application + + def __call__(self, environ, start_response): + request_started.send(sender=self.__class__) + try: + for data in self.application(environ, start_response): + yield data + finally: + request_finished.send(sender=self.__class__) diff --git a/DjangoPlugin/tracdjangoplugin/tests.py b/DjangoPlugin/tracdjangoplugin/tests.py index 0731e2a..bacd529 100644 --- a/DjangoPlugin/tracdjangoplugin/tests.py +++ b/DjangoPlugin/tracdjangoplugin/tests.py @@ -1,13 +1,20 @@ from functools import partial +try: + from unittest.mock import Mock +except ImportError: + from mock import Mock + +from django.core.signals import request_finished, request_started from django.contrib.auth.forms import AuthenticationForm from django.contrib.auth.models import User -from django.test import TestCase +from django.test import SimpleTestCase, TestCase from trac.test import EnvironmentStub, MockRequest from trac.web.api import RequestDone from trac.web.main import RequestDispatcher +from tracdjangoplugin.middlewares import DjangoDBManagementMiddleware from tracdjangoplugin.plugins import PlainLoginComponent @@ -127,3 +134,53 @@ def test_login_invalid_username_uppercased(self): def test_login_invalid_inactive_user(self): User.objects.create_user(username="test", password="test", is_active=False) self.assertLoginFails(username="test", password="test") + + +class DjangoDBManagementMiddlewareTestCase(SimpleTestCase): + @classmethod + def setUpClass(cls): + # Remove receivers from the request_started and request_finished signals, + # replacing them with a mock object so we can still check if they were called. + super(DjangoDBManagementMiddlewareTestCase, cls).setUpClass() + cls._original_signal_receivers = {} + cls.signals = {} + for signal in [request_started, request_finished]: + cls.signals[signal] = Mock() + cls._original_signal_receivers[signal] = signal.receivers + signal.receivers = [] + signal.connect(cls.signals[signal]) + + @classmethod + def tearDownClass(cls): + # Restore the signals we modified in setUpClass() to what they were before + super(DjangoDBManagementMiddlewareTestCase, cls).tearDownClass() + for signal, original_receivers in cls._original_signal_receivers.items(): + # messing about with receivers directly is not an official API, so we need to + # call some undocumented methods to make sure caches and such are taken care of. + with signal.lock: + signal.receivers = original_receivers + signal._clear_dead_receivers() + signal.sender_receivers_cache.clear() + + def setUp(self): + super(DjangoDBManagementMiddlewareTestCase, self).setUp() + for mockobj in self.signals.values(): + mockobj.reset_mock() + + def test_request_start_fired(self): + app = DjangoDBManagementMiddleware(lambda environ, start_response: [b"test"]) + output = b"".join(app(None, None)) + self.assertEqual(output, b"test") + self.signals[request_started].assert_called_once() + + def test_request_finished_fired(self): + app = DjangoDBManagementMiddleware(lambda environ, start_response: [b"test"]) + output = b"".join(app(None, None)) + self.assertEqual(output, b"test") + self.signals[request_finished].assert_called_once() + + def test_request_finished_fired_even_with_error(self): + app = DjangoDBManagementMiddleware(lambda environ, start_response: [1 / 0]) + with self.assertRaises(ZeroDivisionError): + list(app(None, None)) + self.signals[request_finished].assert_called_once() diff --git a/DjangoPlugin/tracdjangoplugin/wsgi.py b/DjangoPlugin/tracdjangoplugin/wsgi.py index ab5403e..6f233a4 100644 --- a/DjangoPlugin/tracdjangoplugin/wsgi.py +++ b/DjangoPlugin/tracdjangoplugin/wsgi.py @@ -12,6 +12,11 @@ # Python 3 would perform better here, but we are still on 2.7 for Trac, so leak fds for now. from tracopt.versioncontrol.git import PyGIT +from .middlewares import DjangoDBManagementMiddleware + + +application = DjangoDBManagementMiddleware(application) + PyGIT.close_fds = False trac_dsn = os.getenv("SENTRY_DSN")