Skip to content

Fix error when a shared event loop is unset #1180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1172.added.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Pyright type checking support in tox configuration to improve type safety and compatibility.
1 change: 1 addition & 0 deletions changelog.d/1177.fixed.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``RuntimeError: There is no current event loop in thread 'MainThread'`` when using shared event loops after any test unsets the event loop (such as when using ``asyncio.run`` and ``asyncio.Runner``).
111 changes: 98 additions & 13 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@
PytestPluginManager,
)

_seen_markers: set[int] = set()


def _warn_scope_deprecation_once(marker_id: int) -> None:
"""Issues deprecation warning exactly once per marker ID."""
if marker_id not in _seen_markers:
_seen_markers.add(marker_id)
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand All @@ -62,7 +72,9 @@
_ScopeName = Literal["session", "package", "module", "class", "function"]
_R = TypeVar("_R", bound=Union[Awaitable[Any], AsyncIterator[Any]])
_P = ParamSpec("_P")
T = TypeVar("T")
FixtureFunction = Callable[_P, _R]
CoroutineFunction = Callable[_P, Awaitable[T]]


class PytestAsyncioError(Exception):
Expand Down Expand Up @@ -291,7 +303,7 @@ def _asyncgen_fixture_wrapper(
gen_obj = fixture_function(*args, **kwargs)

async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
res = await gen_obj.__anext__()
return res

context = contextvars.copy_context()
Expand All @@ -304,7 +316,7 @@ def finalizer() -> None:

async def async_finalizer() -> None:
try:
await gen_obj.__anext__() # type: ignore[union-attr]
await gen_obj.__anext__()
except StopAsyncIteration:
pass
else:
Expand Down Expand Up @@ -333,8 +345,7 @@ def _wrap_async_fixture(
runner: Runner,
request: FixtureRequest,
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:

@functools.wraps(fixture_function) # type: ignore[arg-type]
@functools.wraps(fixture_function)
def _async_fixture_wrapper(
*args: AsyncFixtureParams.args,
**kwargs: AsyncFixtureParams.kwargs,
Expand Down Expand Up @@ -447,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
return inspect.iscoroutinefunction(func)

def runtest(self) -> None:
synchronized_obj = wrap_in_sync(self.obj)
synchronized_obj = get_async_test_wrapper(self, self.obj)
with MonkeyPatch.context() as c:
c.setattr(self, "obj", synchronized_obj)
super().runtest()
Expand Down Expand Up @@ -489,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
)

def runtest(self) -> None:
synchronized_obj = wrap_in_sync(self.obj)
synchronized_obj = get_async_test_wrapper(self, self.obj)
with MonkeyPatch.context() as c:
c.setattr(self, "obj", synchronized_obj)
super().runtest()
Expand All @@ -511,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
)

def runtest(self) -> None:
synchronized_obj = wrap_in_sync(self.obj.hypothesis.inner_test)
synchronized_obj = get_async_test_wrapper(
self,
self.obj.hypothesis.inner_test,
)
with MonkeyPatch.context() as c:
c.setattr(self.obj.hypothesis, "inner_test", synchronized_obj)
super().runtest()
Expand Down Expand Up @@ -602,6 +616,62 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
asyncio.set_event_loop(loop)


_session_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
contextvars.ContextVar(
"_session_loop",
default=None,
)
)
_package_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
contextvars.ContextVar(
"_package_loop",
default=None,
)
)
_module_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
contextvars.ContextVar(
"_module_loop",
default=None,
)
)
_class_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
contextvars.ContextVar(
"_class_loop",
default=None,
)
)
_function_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
contextvars.ContextVar(
"_function_loop",
default=None,
)
)

_SCOPE_TO_CONTEXTVAR = {
"session": _session_loop,
"package": _package_loop,
"module": _module_loop,
"class": _class_loop,
"function": _function_loop,
}


def _get_or_restore_event_loop(loop_scope: _ScopeName) -> asyncio.AbstractEventLoop:
"""
Get or restore the appropriate event loop for the given scope.
If we have a shared loop for this scope, restore and return it.
Otherwise, get the current event loop or create a new one.
"""
shared_loop = _SCOPE_TO_CONTEXTVAR[loop_scope].get()
if shared_loop is not None:
policy = _get_event_loop_policy()
policy.set_event_loop(shared_loop)
return shared_loop
else:
return _get_event_loop_no_warn()


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
"""
Expand Down Expand Up @@ -652,9 +722,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
return None


def wrap_in_sync(
func: Callable[..., Awaitable[Any]],
):
def get_async_test_wrapper(
item: Function,
func: CoroutineFunction[_P, T],
) -> Callable[_P, None]:
"""Returns a synchronous wrapper for the specified async test function."""
marker = item.get_closest_marker("asyncio")
assert marker is not None
default_loop_scope = _get_default_test_loop_scope(item.config)
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
return _wrap_in_sync(func, loop_scope)


def _wrap_in_sync(
func: CoroutineFunction[_P, T],
loop_scope: _ScopeName,
) -> Callable[_P, None]:
"""
Return a sync wrapper around an async function executing it in the
current event loop.
Expand All @@ -663,7 +746,7 @@ def wrap_in_sync(
@functools.wraps(func)
def inner(*args, **kwargs):
coro = func(*args, **kwargs)
_loop = _get_event_loop_no_warn()
_loop = _get_or_restore_event_loop(loop_scope)
task = asyncio.ensure_future(coro, loop=_loop)
try:
_loop.run_until_complete(task)
Expand Down Expand Up @@ -746,7 +829,7 @@ def _get_marked_loop_scope(
if "scope" in asyncio_marker.kwargs:
if "loop_scope" in asyncio_marker.kwargs:
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))
_warn_scope_deprecation_once(id(asyncio_marker))
scope = asyncio_marker.kwargs.get("loop_scope") or asyncio_marker.kwargs.get(
"scope"
)
Expand All @@ -756,7 +839,7 @@ def _get_marked_loop_scope(
return scope


def _get_default_test_loop_scope(config: Config) -> _ScopeName:
def _get_default_test_loop_scope(config: Config) -> Any:
return config.getini("asyncio_default_test_loop_scope")


Expand Down Expand Up @@ -784,6 +867,8 @@ def _scoped_runner(
debug_mode = _get_asyncio_debug(request.config)
with _temporary_event_loop_policy(new_loop_policy):
runner = Runner(debug=debug_mode).__enter__()
shared_loop = runner.get_loop()
_SCOPE_TO_CONTEXTVAR[scope].set(shared_loop)
try:
yield runner
except Exception as e:
Expand Down
Loading
Loading