diff --git a/docs/database.rst b/docs/database.rst index fcdd219a..75666ad3 100644 --- a/docs/database.rst +++ b/docs/database.rst @@ -60,6 +60,80 @@ select using an argument to the ``django_db`` mark:: def test_spam(): pass # test relying on transactions + +Async tests and database transactions +------------------------------------- + +``pytest-django`` supports async tests that use Django's async ORM APIs. +This requires the `pytest-asyncio `_ +plugin and marking your tests appropriately. + +Requirements +------------ + +- Install ``pytest-asyncio``. +- Mark async tests with both ``@pytest.mark.asyncio`` and + ``@pytest.mark.django_db`` (or request the ``db``/``transactional_db`` fixtures). + +Example (async ORM with transactional rollback per test):: + + import pytest + + @pytest.mark.asyncio + @pytest.mark.django_db + async def test_async_db_is_isolated(): + assert await Item.objects.acount() == 0 + await Item.objects.acreate(name="example") + assert await Item.objects.acount() == 1 + # changes are rolled back after the test + +.. _`async-db-behavior`: + +Behavior of ``db`` in async tests +--------------------------------- + +Tests using ``db`` wrap each test in a transaction and roll that transaction back at the end +(like ``django.test.TestCase``). In Django, transactions are bound to the database +connection, which is unique per thread. This means that all your database changes +must be made within the same thread to ensure they are rolled back before the next test. + +Django Async ORM calls, as of writing, use the ``asgiref.sync.sync_to_async`` +decorator to run the ORM calls on a dedicated thread executor. + +For async tests, pytest-django ensures the transaction +setup/teardown happens via ``asgiref.sync.sync_to_async``, which means the transaction is started & run on the +same thread on which async orm calls inside your test, like ``aget()`` are made. This ensures your test code +can safely modify the database using the async calls, as all its queries will be rolled back after the test. + +Tests using ``transactional_db`` flush the database between tests. This means that no matter in which thread +your test modifies the database, the changes will be removed after the test. This means you can avoid thinking +about sync/async database access if your test uses ``transactional_db``, at the cost of slower tests: +A flush is generally slower than rolling back a transaction. + +.. _`db-thread-safeguards`: + +Safeguards against database access from different threads +--------------------------------------------------------- +When using the database in a test with transaction rollback, you must ensure that +database access is only done from the same thread that the test is running on. + +To avoid your fixtures/tests making changes outside the test thread, and as a result, the transaction, pytest-django +actively restricts where database connections may be opened in async tests: + +- In async tests using ``db``: database access is only allowed from the single + thread used by ``SyncToAsync``. Using sync fixtures that touch the database in + an async test will raise:: + + RuntimeError: Database access is only allowed in an async context, modify your + test fixtures to be async or use the transactional_db fixture. + + Fix by converting those fixtures to async (use ``pytest_asyncio.fixture``) and + using Django's async ORM methods (e.g. ``.acreate()``, ``.aget()``, ``.acount()``), + or by requesting ``transactional_db`` if you must keep sync fixtures. + See :ref:`async-db-behavior` for more details. + + + .. _`multi-db`: Tests requiring multiple databases @@ -524,3 +598,4 @@ Put this in ``conftest.py``:: django_db_blocker.unblock() yield django_db_blocker.restore() + diff --git a/pyproject.toml b/pyproject.toml index 75915cc8..b8d2b794 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,10 @@ coverage = [ "coverage[toml]", "coverage-enable-subprocess", ] +async = [ + "asgiref>=3.9.1", + "pytest-asyncio", +] postgres = [ "psycopg[binary]", ] diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 6f7929be..5f719971 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: + from collections.abc import AsyncGenerator from typing import Any, Callable, Literal, Optional, Union import django @@ -203,8 +204,49 @@ def django_db_setup( ) +def _build_pytest_django_test_case( + test_case_class: type[django.test.TestCase], + *, + reset_sequences: bool, + serialized_rollback: bool, + databases: _DjangoDbDatabases, + available_apps: _DjangoDbAvailableApps, + skip_django_testcase_class_setup: bool, +) -> type[django.test.TestCase]: + # Build a custom TestCase subclass with configured attributes and optional + # overrides to skip Django's TestCase class-level setup/teardown. + import django.test # local import to avoid hard dependency at import time + + _reset_sequences = reset_sequences + _serialized_rollback = serialized_rollback + _databases = databases + _available_apps = available_apps + + class PytestDjangoTestCase(test_case_class): + reset_sequences = _reset_sequences + serialized_rollback = _serialized_rollback + if _databases is not None: + databases = _databases + if _available_apps is not None: + available_apps = _available_apps + + if skip_django_testcase_class_setup: + + @classmethod + def setUpClass(cls) -> None: + # Skip django.test.TestCase.setUpClass, call its super instead + super(django.test.TestCase, cls).setUpClass() + + @classmethod + def tearDownClass(cls) -> None: + # Skip django.test.TestCase.tearDownClass, call its super instead + super(django.test.TestCase, cls).tearDownClass() + + return PytestDjangoTestCase + + @pytest.fixture -def _django_db_helper( +def _sync_django_db_helper( request: pytest.FixtureRequest, django_db_setup: None, # noqa: ARG001 django_db_blocker: DjangoDbBlocker, @@ -250,41 +292,14 @@ def _django_db_helper( else: test_case_class = django.test.TestCase - _reset_sequences = reset_sequences - _serialized_rollback = serialized_rollback - _databases = databases - _available_apps = available_apps - - class PytestDjangoTestCase(test_case_class): # type: ignore[misc,valid-type] - reset_sequences = _reset_sequences - serialized_rollback = _serialized_rollback - if _databases is not None: - databases = _databases - if _available_apps is not None: - available_apps = _available_apps - - # For non-transactional tests, skip executing `django.test.TestCase`'s - # `setUpClass`/`tearDownClass`, only execute the super class ones. - # - # `TestCase`'s class setup manages the `setUpTestData`/class-level - # transaction functionality. We don't use it; instead we (will) offer - # our own alternatives. So it only adds overhead, and does some things - # which conflict with our (planned) functionality, particularly, it - # closes all database connections in `tearDownClass` which inhibits - # wrapping tests in higher-scoped transactions. - # - # It's possible a new version of Django will add some unrelated - # functionality to these methods, in which case skipping them completely - # would not be desirable. Let's cross that bridge when we get there... - if not transactional: - - @classmethod - def setUpClass(cls) -> None: - super(django.test.TestCase, cls).setUpClass() - - @classmethod - def tearDownClass(cls) -> None: - super(django.test.TestCase, cls).tearDownClass() + PytestDjangoTestCase = _build_pytest_django_test_case( + test_case_class, + reset_sequences=reset_sequences, + serialized_rollback=serialized_rollback, + databases=databases, + available_apps=available_apps, + skip_django_testcase_class_setup=(not transactional), + ) PytestDjangoTestCase.setUpClass() @@ -300,6 +315,112 @@ def tearDownClass(cls) -> None: PytestDjangoTestCase.doClassCleanups() +try: + import pytest_asyncio +except ImportError: + + async def _async_django_db_helper( + request: pytest.FixtureRequest, # noqa: ARG001 + django_db_blocker: DjangoDbBlocker, # noqa: ARG001 + ) -> AsyncGenerator[None, None]: + raise RuntimeError( + "The `pytest_asyncio` plugin is required to use the `async_django_db` fixture." + ) + yield # pragma: no cover +else: + + @pytest_asyncio.fixture + async def _async_django_db_helper( + request: pytest.FixtureRequest, + django_db_blocker: DjangoDbBlocker, + ) -> AsyncGenerator[None, None]: + # same as _sync_django_db_helper, except for running the transaction start and rollback wrapped in a + # `sync_to_async` call + transactional, reset_sequences, databases, serialized_rollback, available_apps = ( + _get_django_db_settings(request) + ) + + with django_db_blocker.unblock(async_only=True): + import django.db + import django.test + + test_case_class = django.test.TestCase + + PytestDjangoTestCase = _build_pytest_django_test_case( + test_case_class, + reset_sequences=reset_sequences, + serialized_rollback=serialized_rollback, + databases=databases, + available_apps=available_apps, + skip_django_testcase_class_setup=True, + ) + + from asgiref.sync import sync_to_async + + await sync_to_async(PytestDjangoTestCase.setUpClass)() + + test_case = PytestDjangoTestCase(methodName="__init__") + await sync_to_async(test_case._pre_setup, thread_sensitive=True)() + + yield + + await sync_to_async(test_case._post_teardown, thread_sensitive=True)() + + await sync_to_async(PytestDjangoTestCase.tearDownClass)() + + await sync_to_async(PytestDjangoTestCase.doClassCleanups)() + + +def _get_django_db_settings(request: pytest.FixtureRequest) -> _DjangoDb: + django_marker = request.node.get_closest_marker("django_db") + if django_marker: + ( + transactional, + reset_sequences, + databases, + serialized_rollback, + available_apps, + ) = validate_django_db(django_marker) + else: + ( + transactional, + reset_sequences, + databases, + serialized_rollback, + available_apps, + ) = False, False, None, False, None + + transactional = ( + transactional + or reset_sequences + or ("transactional_db" in request.fixturenames or "live_server" in request.fixturenames) + ) + + reset_sequences = reset_sequences or ("django_db_reset_sequences" in request.fixturenames) + serialized_rollback = serialized_rollback or ( + "django_db_serialized_rollback" in request.fixturenames + ) + return transactional, reset_sequences, databases, serialized_rollback, available_apps + + +@pytest.fixture +def _django_db_helper( + request: pytest.FixtureRequest, + django_db_setup: None, # noqa: ARG001 + django_db_blocker: DjangoDbBlocker, # noqa: ARG001 +) -> None: + asyncio_marker = request.node.get_closest_marker("asyncio") + transactional, *_ = _get_django_db_settings(request) + if transactional or not asyncio_marker: + # add the original sync fixture + request.getfixturevalue("_sync_django_db_helper") + else: + # add the async fixture. Will run it inside the event loop, which will cause the sync to async calls to + # start a transaction on the thread safe executor for that loop. This allows us to roll back orm calls made + # in that async test context. + request.getfixturevalue("_async_django_db_helper") + + def _django_db_signature( transaction: bool = False, reset_sequences: bool = False, diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 314fb856..88071b13 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -11,6 +11,7 @@ import os import pathlib import sys +import threading import types from collections.abc import Generator from contextlib import AbstractContextManager @@ -21,8 +22,10 @@ from .django_compat import is_django_unittest from .fixtures import ( + _async_django_db_helper, # noqa: F401 _django_db_helper, # noqa: F401 _live_server_helper, # noqa: F401 + _sync_django_db_helper, # noqa: F401 admin_client, # noqa: F401 admin_user, # noqa: F401 async_client, # noqa: F401 @@ -54,7 +57,7 @@ if TYPE_CHECKING: - from typing import Any, NoReturn + from typing import Any, Callable, NoReturn import django @@ -817,7 +820,7 @@ def __init__(self, *, _ispytest: bool = False) -> None: ) self._history = [] # type: ignore[var-annotated] - self._real_ensure_connection = None + self._real_ensure_connection: None | Callable[[Any], Any] = None @property def _dj_db_wrapper(self) -> django.db.backends.base.base.BaseDatabaseWrapper: @@ -833,18 +836,41 @@ def _dj_db_wrapper(self) -> django.db.backends.base.base.BaseDatabaseWrapper: def _save_active_wrapper(self) -> None: self._history.append(self._dj_db_wrapper.ensure_connection) - def _blocking_wrapper(*args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002 + def _blocking_wrapper(self, *args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002 __tracebackhide__ = True raise RuntimeError( "Database access not allowed, " 'use the "django_db" mark, or the ' - '"db" or "transactional_db" fixtures to enable it.' + '"db" or "transactional_db" fixtures to enable it. ' ) - def unblock(self) -> AbstractContextManager[None]: + def _unblocked_async_only(self, wrapper_self: Any, *args: Any, **kwargs: Any) -> None: + __tracebackhide__ = True + from asgiref.sync import SyncToAsync + + is_in_sync_to_async_thread = ( + next(iter(SyncToAsync.single_thread_executor._threads)) == threading.current_thread() + ) + if not is_in_sync_to_async_thread: + raise RuntimeError( + "Database access is only allowed in an async context, " + "modify your test fixtures to be async or use the transactional_db fixture." + "See https://pytest-django.readthedocs.io/en/latest/database.html#db-thread-safeguards for more information." + ) + if self._real_ensure_connection is not None: + self._real_ensure_connection(wrapper_self, *args, **kwargs) + + def unblock(self, async_only: bool = False) -> AbstractContextManager[None]: """Enable access to the Django database.""" self._save_active_wrapper() - self._dj_db_wrapper.ensure_connection = self._real_ensure_connection + if async_only: + + def _method(wrapper_self: Any, *args: Any, **kwargs: Any) -> None: + return self._unblocked_async_only(wrapper_self, *args, **kwargs) + + self._dj_db_wrapper.ensure_connection = _method + else: + self._dj_db_wrapper.ensure_connection = self._real_ensure_connection return _DatabaseBlockerContextManager(self) def block(self) -> AbstractContextManager[None]: diff --git a/tests/test_async_db.py b/tests/test_async_db.py new file mode 100644 index 00000000..e0b2a75e --- /dev/null +++ b/tests/test_async_db.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest +from _pytest.mark import MarkDecorator + +from pytest_django_test.app.models import Item + + +try: + import pytest_asyncio +except ImportError: + pytestmark: MarkDecorator = pytest.mark.skip("pytest-asyncio is not installed") + fixturemark: MarkDecorator = pytest.mark.skip("pytest-asyncio is not installed") + +else: + pytestmark = pytest.mark.asyncio + fixturemark = cast(MarkDecorator, pytest_asyncio.fixture) + + +@pytest.mark.parametrize("run_number", [1, 2]) +@pytestmark +@pytest.mark.django_db +async def test_async_db(run_number: int) -> None: # noqa: ARG001 + # test async database usage remains isolated between tests + + assert await Item.objects.acount() == 0 + # make a new item instance, to be rolled back by the transaction wrapper before the next parametrized run + await Item.objects.acreate(name="blah") + assert await Item.objects.acount() == 1 + + +@fixturemark +async def db_item() -> Any: + return await Item.objects.acreate(name="async") + + +@pytest.fixture +def sync_db_item() -> Any: + return Item.objects.create(name="sync") + + +@pytest.mark.usefixtures("db_item", "sync_db_item") +@pytestmark +@pytest.mark.xfail(strict=True, reason="Sync fixture used in async test") +@pytest.mark.django_db +async def test_db_item() -> None: + pass diff --git a/tox.ini b/tox.ini index 5ffeeead..e03a46f8 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = [testenv] dependency_groups = testing + !dj42: async coverage: coverage mysql: mysql postgres: postgres @@ -43,7 +44,9 @@ commands = coverage: coverage xml [testenv:linting] -dependency_groups = linting +dependency_groups = + linting + async commands = ruff check {posargs:pytest_django pytest_django_test tests} ruff format --quiet --diff {posargs:pytest_django pytest_django_test tests}