diff --git a/CHANGES/3471.bugfix b/CHANGES/3471.bugfix new file mode 100644 index 00000000000..98e28e68465 --- /dev/null +++ b/CHANGES/3471.bugfix @@ -0,0 +1,2 @@ +Use the same task for app initialization and web server handling in gunicorn workers. +It allows to use Python3.7 context vars smoothly. diff --git a/aiohttp/worker.py b/aiohttp/worker.py index 4681b47eadb..73ba6e38f69 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -6,7 +6,7 @@ import signal import sys from types import FrameType -from typing import Any, Optional # noqa +from typing import Any, Awaitable, Callable, Optional, Union # noqa from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat from gunicorn.workers import base @@ -14,6 +14,7 @@ from aiohttp import web from .helpers import set_result +from .web_app import Application from .web_log import AccessLogger try: @@ -37,7 +38,6 @@ class GunicornWebWorker(base.Worker): def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover super().__init__(*args, **kw) - self._runner = None # type: Optional[web.AppRunner] self._task = None # type: Optional[asyncio.Task[None]] self.exit_code = 0 self._notify_waiter = None # type: Optional[asyncio.Future[bool]] @@ -52,35 +52,39 @@ def init_process(self) -> None: super().init_process() def run(self) -> None: - access_log = self.log.access_log if self.cfg.accesslog else None - params = dict( - logger=self.log, - keepalive_timeout=self.cfg.keepalive, - access_log=access_log, - access_log_format=self._get_valid_log_format( - self.cfg.access_log_format)) - if asyncio.iscoroutinefunction(self.wsgi): # type: ignore - self.wsgi = self.loop.run_until_complete( - self.wsgi()) # type: ignore - self._runner = web.AppRunner(self.wsgi, **params) - self.loop.run_until_complete(self._runner.setup()) self._task = self.loop.create_task(self._run()) try: # ignore all finalization problems self.loop.run_until_complete(self._task) - except Exception as error: - self.log.exception(error) + except Exception: + self.log.exception("Exception in gunicorn worker") if sys.version_info >= (3, 6): - if hasattr(self.loop, 'shutdown_asyncgens'): - self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) self.loop.close() sys.exit(self.exit_code) async def _run(self) -> None: + if isinstance(self.wsgi, Application): + app = self.wsgi + elif asyncio.iscoroutinefunction(self.wsgi): + app = await self.wsgi() + else: + raise RuntimeError("wsgi app should be either Application or " + "async function returning Application, got {}" + .format(self.wsgi)) + access_log = self.log.access_log if self.cfg.accesslog else None + runner = web.AppRunner(app, + logger=self.log, + keepalive_timeout=self.cfg.keepalive, + access_log=access_log, + access_log_format=self._get_valid_log_format( + self.cfg.access_log_format)) + await runner.setup() + ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None - runner = self._runner + runner = runner assert runner is not None server = runner.server assert server is not None diff --git a/tests/test_worker.py b/tests/test_worker.py index d78662e0074..27c75248774 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -9,7 +9,6 @@ import pytest from aiohttp import web -from aiohttp.test_utils import make_mocked_coro base_worker = pytest.importorskip('aiohttp.worker') @@ -42,13 +41,15 @@ def __init__(self): self.wsgi = web.Application() -class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): # type: ignore # noqa +class AsyncioWorker(BaseTestWorker, # type: ignore + base_worker.GunicornWebWorker): pass PARAMS = [AsyncioWorker] if uvloop is not None: - class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): # type: ignore # noqa + class UvloopWorker(BaseTestWorker, # type: ignore + base_worker.GunicornUVLoopWebWorker): pass PARAMS.append(UvloopWorker) @@ -78,12 +79,13 @@ def test_run(worker, loop) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + worker.cfg.is_ssl = False + worker.sockets = [] worker.loop = loop - worker._run = make_mocked_coro(None) with pytest.raises(SystemExit): worker.run() - assert worker._run.called + worker.log.exception.assert_not_called() assert loop.is_closed() @@ -91,6 +93,8 @@ def test_run_async_factory(worker, loop) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + worker.cfg.is_ssl = False + worker.sockets = [] app = worker.wsgi async def make_app(): @@ -98,10 +102,24 @@ async def make_app(): worker.wsgi = make_app worker.loop = loop - worker._run = make_mocked_coro(None) + worker.alive = False + with pytest.raises(SystemExit): + worker.run() + worker.log.exception.assert_not_called() + assert loop.is_closed() + + +def test_run_not_app(worker, loop) -> None: + worker.log = mock.Mock() + worker.cfg = mock.Mock() + worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT + + worker.loop = loop + worker.wsgi = "not-app" + worker.alive = False with pytest.raises(SystemExit): worker.run() - assert worker._run.called + worker.log.exception.assert_called_with('Exception in gunicorn worker') assert loop.is_closed() @@ -197,15 +215,11 @@ async def test__run_ok_parent_changed(worker, loop, worker.cfg.max_requests = 0 worker.cfg.is_ssl = False - worker._runner = web.AppRunner(worker.wsgi) - await worker._runner.setup() - await worker._run() worker.notify.assert_called_with() worker.log.info.assert_called_with("Parent changed, shutting down: %s", worker) - assert worker._runner.server is None async def test__run_exc(worker, loop, aiohttp_unused_port) -> None: @@ -223,9 +237,6 @@ async def test__run_exc(worker, loop, aiohttp_unused_port) -> None: worker.cfg.max_requests = 0 worker.cfg.is_ssl = False - worker._runner = web.AppRunner(worker.wsgi) - await worker._runner.setup() - def raiser(): waiter = worker._notify_waiter worker.alive = False @@ -235,37 +246,6 @@ def raiser(): await worker._run() worker.notify.assert_called_with() - assert worker._runner.server is None - - -async def test__run_ok_max_requests_exceeded(worker, loop, - aiohttp_unused_port): - skip_if_no_dict(loop) - - worker.ppid = os.getppid() - worker.alive = True - worker.servers = {} - sock = socket.socket() - addr = ('localhost', aiohttp_unused_port()) - sock.bind(addr) - worker.sockets = [sock] - worker.log = mock.Mock() - worker.loop = loop - worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT - worker.cfg.max_requests = 10 - worker.cfg.is_ssl = False - - worker._runner = web.AppRunner(worker.wsgi) - await worker._runner.setup() - worker._runner.server.requests_count = 30 - - await worker._run() - - worker.notify.assert_called_with() - worker.log.info.assert_called_with("Max requests, shutting down: %s", - worker) - - assert worker._runner.server is None def test__create_ssl_context_without_certs_and_ciphers(worker) -> None: