diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst
index 7d2cf44a61..2f7494cbaa 100644
--- a/docs/source/reference-core.rst
+++ b/docs/source/reference-core.rst
@@ -1816,6 +1816,45 @@ to spawn a child thread, and then use a :ref:`memory channel
 
 .. literalinclude:: reference-core/from-thread-example.py
 
+.. _worker_processes:
+
+Worker Processes
+----------------
+
+Given that Python (and CPython in particular) has ongoing difficulties with
+CPU-bound work, Trio provides a way to dispatch synchronous function calls to
+special subprocesses known as "Worker Processes". By default, Trio will create as many
+workers as the system has CPUs (as reported by :func:`os.cpu_count`), allowing fair
+and truly parallel dispatch of CPU-bound work. As with Trio threads, these processes
+are cached in a process pool to minimize latency and resource usage. Despite this,
+executing a function in a process is at best an order of magnitude slower than in
+a thread, and possibly even slower when dealing with large arguments or a cold pool.
+Therefore, we recommend avoiding worker process dispatch for functions with a
+duration of less than about 10 ms.
+
+Unlike threads, subprocesses are strongly isolated from the parent process, which
+enables two important features that cannot be portably implemented in threads:
+
+  - Forceful cancellation: a deadlocked call or infinite loop can be cancelled
+    by completely terminating the process.
+  - Protection from errors: if a call segfaults or an extension module has an
+    unrecoverable error, the worker may die but Trio will raise
+    :exc:`BrokenWorkerError` and carry on.
+
+In both cases the workers die suddenly and violently, at an unpredictable point
+in the execution of the dispatched function, so avoid using the cancellation feature
+if loss of intermediate results, writes to the filesystem, or shared memory writes
+could leave the larger system in an incoherent state.
+
+.. module:: trio.to_process
+.. currentmodule:: trio
+
+Putting CPU-bound functions in worker processes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: trio.to_process.run_sync
+
+.. autofunction:: trio.to_process.current_default_process_limiter
 
 Exceptions and warnings
 -----------------------
@@ -1834,6 +1873,8 @@ Exceptions and warnings
 
 .. autoexception:: BrokenResourceError
 
+.. autoexception:: BrokenWorkerError
+
 .. autoexception:: RunFinishedError
 
 .. autoexception:: TrioInternalError
diff --git a/newsfragments/1781.headline.rst b/newsfragments/1781.headline.rst
new file mode 100644
index 0000000000..699040f32b
--- /dev/null
+++ b/newsfragments/1781.headline.rst
@@ -0,0 +1,4 @@
+Trio now provides `multiprocessing`-based worker processes for the delegation
+of CPU-bound work via :func:`trio.to_process.run_sync`. The API is comparable to
+:func:`trio.to_thread.run_sync`, but true cancellation is achieved by killing the
+worker process.
\ No newline at end of file
diff --git a/trio/__init__.py b/trio/__init__.py
index d66ffceea9..5600bf0810 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -90,12 +90,15 @@
 from ._deprecate import TrioDeprecationWarning
 
+from ._worker_processes import BrokenWorkerError
+
 # Submodules imported by default
 from . import lowlevel
 from . import socket
 from . import abc
 from . import from_thread
 from . import to_thread
+from . import to_process
 
 # Not imported by default, but mentioned here so static analysis tools like
 # pylint will know that it exists.
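[Editor's note: the documentation hunk above describes the new ``trio.to_process.run_sync`` API but this diff does not ship a usage example. The following is an illustrative sketch only, not part of the PR; it assumes the API as defined in ``trio/_worker_processes.py`` below, and the CPU-bound function ``crunch`` is made up.]

    # Hypothetical usage sketch of trio.to_process.run_sync (not part of this diff).
    import trio


    def crunch(n):
        # A stand-in CPU-bound function; it must live at module level so it can
        # be pickled and sent to the worker process.
        total = 0
        for i in range(n):
            total += i * i
        return total


    async def main():
        # cancellable=True means cancellation kills the worker outright via
        # SIGKILL/TerminateProcess instead of waiting for the call to finish.
        result = await trio.to_process.run_sync(crunch, 10_000_000, cancellable=True)
        print("result:", result)


    if __name__ == "__main__":
        # The __main__ guard matters because multiprocessing may re-import this
        # module inside the worker on spawn-based platforms such as Windows.
        trio.run(main)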
diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index a1071519e9..2e5e7d27cd 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -179,6 +179,25 @@ LPOVERLAPPED lpOverlapped ); +BOOL PeekNamedPipe( + HANDLE hNamedPipe, + LPVOID lpBuffer, + DWORD nBufferSize, + LPDWORD lpBytesRead, + LPDWORD lpTotalBytesAvail, + LPDWORD lpBytesLeftThisMessage +); + +BOOL GetNamedPipeHandleStateA( + HANDLE hNamedPipe, + LPDWORD lpState, + LPDWORD lpCurInstances, + LPDWORD lpMaxCollectionCount, + LPDWORD lpCollectDataTimeout, + LPSTR lpUserName, + DWORD nMaxUserNameSize +); + // From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h typedef struct _AFD_POLL_HANDLE_INFO { HANDLE Handle; @@ -245,6 +264,7 @@ class ErrorCodes(enum.IntEnum): ERROR_INVALID_PARMETER = 87 ERROR_NOT_FOUND = 1168 ERROR_NOT_SOCKET = 10038 + ERROR_MORE_DATA = 234 class FileFlags(enum.IntEnum): @@ -296,6 +316,13 @@ class IoControlCodes(enum.IntEnum): IOCTL_AFD_POLL = 0x00012024 +class PipeModes(enum.IntFlag): + PIPE_WAIT = 0x00000000 + PIPE_NOWAIT = 0x00000001 + PIPE_READMODE_BYTE = 0x00000000 + PIPE_READMODE_MESSAGE = 0x00000002 + + ################################################################ # Generic helpers ################################################################ @@ -321,3 +348,21 @@ def raise_winerror(winerror=None, *, filename=None, filename2=None): _, msg = ffi.getwinerror(winerror) # https://docs.python.org/3/library/exceptions.html#OSError raise OSError(0, msg, filename, winerror, filename2) + + +def get_pipe_state(handle): + lpState = ffi.new("LPDWORD") + if not kernel32.GetNamedPipeHandleStateA( + _handle(handle), lpState, ffi.NULL, ffi.NULL, ffi.NULL, ffi.NULL, 0 + ): + raise_winerror() # pragma: no cover + return lpState[0] + + +def peek_pipe_message_left(handle): + left = ffi.new("LPDWORD") + if not kernel32.PeekNamedPipe( + _handle(handle), ffi.NULL, 0, ffi.NULL, ffi.NULL, left + ): + raise_winerror() # pragma: no cover + return left[0] diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index fb420535f4..e04ef84d32 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -1,9 +1,16 @@ import sys from typing import TYPE_CHECKING from . 
import _core -from ._abc import SendStream, ReceiveStream +from ._abc import SendStream, ReceiveStream, SendChannel, ReceiveChannel from ._util import ConflictDetector, Final -from ._core._windows_cffi import _handle, raise_winerror, kernel32, ffi +from ._core._windows_cffi import ( + _handle, + raise_winerror, + kernel32, + ffi, + ErrorCodes, + peek_pipe_message_left, +) assert sys.platform == "win32" or not TYPE_CHECKING @@ -132,3 +139,70 @@ async def receive_some(self, max_bytes=None) -> bytes: async def aclose(self): await self._handle_holder.aclose() + + +class PipeSendChannel(SendChannel[bytes]): + """Represents a message stream over a pipe object.""" + + def __init__(self, handle: int) -> None: + self._pss = PipeSendStream(handle) + # needed for "detach" via _handle_holder.handle = -1 + self._handle_holder = self._pss._handle_holder + + async def send(self, value: bytes): + # Works just fine if the pipe is message-oriented + await self._pss.send_all(value) + + async def aclose(self): + await self._handle_holder.aclose() + + +class PipeReceiveChannel(ReceiveChannel[bytes]): + """Represents a message stream over a pipe object.""" + + def __init__(self, handle: int) -> None: + self._handle_holder = _HandleHolder(handle) + self._conflict_detector = ConflictDetector( + "another task is currently using this pipe" + ) + + async def receive(self) -> bytes: + with self._conflict_detector: + buffer = bytearray(DEFAULT_RECEIVE_SIZE) + try: + received = await self._receive_some_into(buffer) + except OSError as e: + if e.winerror != ErrorCodes.ERROR_MORE_DATA: + raise # pragma: no cover + left = peek_pipe_message_left(self._handle_holder.handle) + # preallocate memory to avoid an extra copy of very large messages + newbuffer = bytearray(DEFAULT_RECEIVE_SIZE + left) + with memoryview(newbuffer) as view: + view[:DEFAULT_RECEIVE_SIZE] = buffer + await self._receive_some_into(view[DEFAULT_RECEIVE_SIZE:]) + return newbuffer + else: + del buffer[received:] + return buffer + + async def _receive_some_into(self, buffer) -> bytes: + if self._handle_holder.closed: + raise _core.ClosedResourceError("this pipe is already closed") + try: + return await _core.readinto_overlapped(self._handle_holder.handle, buffer) + except BrokenPipeError: + if self._handle_holder.closed: + raise _core.ClosedResourceError( + "another task closed this pipe" + ) from None + + # Windows raises BrokenPipeError on one end of a pipe + # whenever the other end closes, regardless of direction. + # Convert this to EndOfChannel. + # + # Do we have to checkpoint manually? We are raising an exception. + await _core.checkpoint() + raise _core.EndOfChannel + + async def aclose(self): + await self._handle_holder.aclose() diff --git a/trio/_worker_processes.py b/trio/_worker_processes.py new file mode 100644 index 0000000000..91da86b842 --- /dev/null +++ b/trio/_worker_processes.py @@ -0,0 +1,341 @@ +import os +from collections import deque +from itertools import count +from multiprocessing import Pipe, Process +from multiprocessing.reduction import ForkingPickler + +from ._core import ( + open_nursery, + RunVar, + CancelScope, + wait_readable, + EndOfChannel, + BrokenResourceError, + checkpoint_if_cancelled, +) +from ._sync import CapacityLimiter +from ._threads import to_thread_run_sync +from ._timeouts import sleep_forever + +_limiter_local = RunVar("proc_limiter") + +# How long a process will idle waiting for new work before gives up and exits. +# This should be longer than a thread timeout proportionately to startup time. 
+IDLE_TIMEOUT = 60 * 10 + +# Sane default might be to expect cpu-bound work +DEFAULT_LIMIT = os.cpu_count() +_proc_counter = count() + +if os.name == "nt": + from trio._windows_pipes import PipeSendChannel, PipeReceiveChannel + from ._wait_for_object import WaitForSingleObject + + # TODO: This uses a thread per-process. Can we do better? + wait_sentinel = WaitForSingleObject +else: + from ._unix_pipes import FdStream + import struct + + wait_sentinel = wait_readable + + +class BrokenWorkerError(RuntimeError): + """Raised when a worker process fails or dies unexpectedly. + + This error is not typically encountered in normal use, and indicates a severe + failure of either Trio or the code that was executing in the worker. + """ + + pass + + +def current_default_process_limiter(): + """Get the default `~trio.CapacityLimiter` used by + `trio.to_process.run_sync`. + + The most common reason to call this would be if you want to modify its + :attr:`~trio.CapacityLimiter.total_tokens` attribute. This attribute + is initialized to the number of CPUs reported by :func:`os.cpu_count`. + + """ + try: + limiter = _limiter_local.get() + except LookupError: + limiter = CapacityLimiter(DEFAULT_LIMIT) + _limiter_local.set(limiter) + return limiter + + +class ProcCache: + def __init__(self): + # The cache is a deque rather than dict here since processes can't remove + # themselves anyways, so we don't need O(1) lookups + self._cache = deque() + # NOTE: avoid thread races between Trio runs by only interacting with + # self._cache via thread-atomic actions like append, pop, del + + def prune(self): + # take advantage of the oldest proc being on the left to + # keep iteration O(dead workers) + try: + while True: + proc = self._cache.popleft() + if proc.is_alive(): + self._cache.appendleft(proc) + return + except IndexError: + # Thread safety: it's necessary to end the iteration using this error + # when the cache is empty, as opposed to `while self._cache`. + pass + + def push(self, proc): + self._cache.append(proc) + + def pop(self): + """Get live worker process or raise IndexError""" + while True: + proc = self._cache.pop() + if proc.is_alive(): + return proc + + def __len__(self): + return len(self._cache) + + +PROC_CACHE = ProcCache() + + +class WorkerProc: + def __init__(self): + child_recv_pipe, self._send_pipe = Pipe(duplex=False) + self._recv_pipe, child_send_pipe = Pipe(duplex=False) + self._proc = Process( + target=self._work, + args=(child_recv_pipe, child_send_pipe), + name=f"Trio worker process {next(_proc_counter)}", + daemon=True, + ) + # The following initialization methods may take a long time + self._proc.start() + + @staticmethod + def _work(recv_pipe, send_pipe): # pragma: no cover + + import inspect + import outcome + + def worker_fn(): + ret = fn(*args) + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a sync function, but {!r} appears to be " + "asynchronous".format(getattr(fn, "__qualname__", fn)) + ) + + return ret + + try: + while recv_pipe.poll(timeout=IDLE_TIMEOUT): + fn, args = recv_pipe.recv() + result = outcome.capture(worker_fn) + # Unlike the thread cache, it's impossible to deliver the + # result from the worker process. 
So shove it onto the queue + # and hope the receiver delivers the result and marks us idle + send_pipe.send(result) + + del fn + del args + del result + finally: + recv_pipe.close() + send_pipe.close() + + async def run_sync(self, sync_fn, *args): + # Neither this nor the child process should be waiting at this point + self._rehabilitate_pipes() + async with open_nursery() as nursery: + # Monitor needed for pypy and other platforms that don't + # promptly raise EndOfChannel + nursery.start_soon(self._child_monitor) + try: + await self._send(ForkingPickler.dumps((sync_fn, args))) + result = ForkingPickler.loads(await self._recv()) + except EndOfChannel: + # Likely the worker died while we were waiting on a pipe + self.kill() # Just make sure + # sleep and let the monitor raise the appropriate error to avoid + # creating any MultiErrors in this codepath + await sleep_forever() + except BaseException: + # Cancellation leaves the process in an unknown state, so + # there is no choice but to kill, anyway it frees the pipe threads. + # For other unknown errors, it's best to clean up similarly. + self.kill() + raise + # Must cancel the _child_monitor task to escape the nursery + nursery.cancel_scope.cancel() + return result.unwrap() + + async def _child_monitor(self): + # If this handle becomes ready, raise a catchable error... + await wait_sentinel(self._proc.sentinel) + # but not if another error or cancel is incoming, those take priority! + await checkpoint_if_cancelled() + raise BrokenWorkerError(f"{self._proc} died unexpectedly") + + def is_alive(self): + # Even if the proc is alive, there is a race condition where it could + # be dying, use join to make sure if necessary. + return self._proc.is_alive() + + def kill(self): + try: + self._proc.kill() + except AttributeError: + self._proc.terminate() + + def join(self, timeout=None): + # Needed for some tests. We have to reach in deeply because + # _proc.join() doesn't report whether the join was successful + return self._proc._popen.wait(timeout) is not None + + if os.name == "nt": + + def _rehabilitate_pipes(self): + # These must be created in an async context, so defer so + # that this object can be instantiated in e.g. a thread + if not hasattr(self, "_send_chan"): + self._send_chan = PipeSendChannel(self._send_pipe.fileno()) + self._recv_chan = PipeReceiveChannel(self._recv_pipe.fileno()) + self._send = self._send_chan.send + self._recv = self._recv_chan.receive + + def __del__(self): + # Avoid __del__ errors on cleanup: GH#174, GH#1767 + # multiprocessing will close them for us + if hasattr(self, "_send_chan"): + self._send_chan._handle_holder.handle = -1 + self._recv_chan._handle_holder.handle = -1 + + else: + + def _rehabilitate_pipes(self): + # These must be created in an async context, so defer so + # that this object can be instantiated in e.g. 
a thread + if not hasattr(self, "_send_stream"): + self._send_stream = FdStream(self._send_pipe.fileno()) + self._recv_stream = FdStream(self._recv_pipe.fileno()) + + async def _recv(self): + buf = await self._recv_exactly(4) + (size,) = struct.unpack("!i", buf) + if size == -1: + buf = await self._recv_exactly(8) + (size,) = struct.unpack("!Q", buf) + return await self._recv_exactly(size) + + async def _recv_exactly(self, size): + result_bytes = bytearray() + while size: + partial_result = await self._recv_stream.receive_some(size) + num_recvd = len(partial_result) + if not num_recvd: + raise EndOfChannel("got end of file during message") + result_bytes.extend(partial_result) + if num_recvd > size: # pragma: no cover + raise RuntimeError("Oversized response") + else: + size -= num_recvd + return result_bytes + + async def _send(self, buf): + n = len(buf) + if n > 0x7FFFFFFF: + pre_header = struct.pack("!i", -1) + header = struct.pack("!Q", n) + await self._send_stream.send_all(pre_header) + await self._send_stream.send_all(header) + await self._send_stream.send_all(buf) + else: + # For wire compatibility with 3.7 and lower + header = struct.pack("!i", n) + if n > 16384: + # The payload is large so Nagle's algorithm won't be triggered + # and we'd better avoid the cost of concatenation. + await self._send_stream.send_all(header) + await self._send_stream.send_all(buf) + else: + # Issue #20540: concatenate before sending, to avoid delays due + # to Nagle's algorithm on a TCP socket. + # Also note we want to avoid sending a 0-length buffer separately, + # to avoid "broken pipe" errors if the other end closed the pipe. + await self._send_stream.send_all(header + buf) + + def __del__(self): + # Avoid __del__ errors on cleanup: GH#174, GH#1767 + # multiprocessing will close them for us + if hasattr(self, "_send_stream"): + self._send_stream._fd_holder.fd = -1 + self._recv_stream._fd_holder.fd = -1 + + +async def to_process_run_sync(sync_fn, *args, cancellable=False, limiter=None): + """Run sync_fn in a separate process + + This is a wrapping of :class:`multiprocessing.Process` that follows the API of + :func:`trio.to_thread.run_sync`. The intended use of this function is limited: + + - Circumvent the GIL to run CPU-bound functions in parallel + - Make blocking APIs or infinite loops truly cancellable through + SIGKILL/TerminateProcess without leaking resources + - Protect the main process from untrusted/unstable code without leaks + + Other :mod:`multiprocessing` features may work but are not officially + supported by Trio, and all the normal :mod:`multiprocessing` caveats apply. + + Args: + sync_fn: An importable or pickleable synchronous callable. See the + :mod:`multiprocessing` documentation for detailed explanation of + limitations. + *args: Positional arguments to pass to sync_fn. If you need keyword + arguments, use :func:`functools.partial`. + cancellable (bool): Whether to allow cancellation of this operation. + Cancellation always involves abrupt termination of the worker process + with SIGKILL/TerminateProcess. + limiter (None, or CapacityLimiter): + An object used to limit the number of simultaneous processes. Most + commonly this will be a `~trio.CapacityLimiter`, but any async + context manager will succeed. + + Returns: + Whatever ``sync_fn(*args)`` returns. + + Raises: + Exception: Whatever ``sync_fn(*args)`` raises. 
+ + """ + if limiter is None: + limiter = current_default_process_limiter() + + async with limiter: + PROC_CACHE.prune() + + while True: + try: + proc = PROC_CACHE.pop() + except IndexError: + proc = await to_thread_run_sync(WorkerProc) + + try: + with CancelScope(shield=not cancellable): + return await proc.run_sync(sync_fn, *args) + except BrokenResourceError: + # Rare case where proc timed out even though it was still alive + # as we popped it. Just retry. + pass + finally: + if proc.is_alive(): + PROC_CACHE.push(proc) diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index 361cd64ce2..fe3fa985f1 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -10,30 +10,63 @@ from ..testing import wait_all_tasks_blocked, check_one_way_stream if sys.platform == "win32": - from .._windows_pipes import PipeSendStream, PipeReceiveStream - from .._core._windows_cffi import _handle, kernel32 + from .._windows_pipes import ( + PipeSendStream, + PipeReceiveStream, + PipeSendChannel, + PipeReceiveChannel, + DEFAULT_RECEIVE_SIZE, + ) + from .._core._windows_cffi import ( + _handle, + kernel32, + PipeModes, + get_pipe_state, + ) from asyncio.windows_utils import pipe + from multiprocessing.connection import Pipe else: pytestmark = pytest.mark.skip(reason="windows only") pipe = None # type: Any PipeSendStream = None # type: Any PipeReceiveStream = None # type: Any + PipeSendChannel = None # type: Any + PipeReceiveChannel = None # type: Any -async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]": - """Makes a new pair of pipes.""" +async def make_pipe_stream() -> "Tuple[PipeSendStream, PipeReceiveStream]": + """Makes a new pair of byte-oriented pipes.""" (r, w) = pipe() + assert not (PipeModes.PIPE_READMODE_MESSAGE & get_pipe_state(r)) return PipeSendStream(w), PipeReceiveStream(r) +async def make_pipe_channel() -> "Tuple[PipeSendChannel, PipeReceiveChannel]": + """Makes a new pair of message-oriented pipes.""" + (r_channel, w_channel) = Pipe(duplex=False) + (r, w) = r_channel.fileno(), w_channel.fileno() + # XXX: Check internal details haven't changed suddenly + assert (r_channel._handle, w_channel._handle) == (r, w) + # XXX: Sabotage _ConnectionBase __del__ + (r_channel._handle, w_channel._handle) = (None, None) + # XXX: Check internal details haven't changed suddenly + assert r_channel.closed and w_channel.closed + assert PipeModes.PIPE_READMODE_MESSAGE & get_pipe_state(r) + return PipeSendChannel(w), PipeReceiveChannel(r) + + async def test_pipe_typecheck(): with pytest.raises(TypeError): PipeSendStream(1.0) with pytest.raises(TypeError): PipeReceiveStream(None) + with pytest.raises(TypeError): + PipeSendChannel(1.0) + with pytest.raises(TypeError): + PipeReceiveChannel(None) -async def test_pipe_error_on_close(): +async def test_pipe_stream_error_on_close(): # Make sure we correctly handle a failure from kernel32.CloseHandle r, w = pipe() @@ -49,8 +82,40 @@ async def test_pipe_error_on_close(): await receive_stream.aclose() -async def test_pipes_combined(): - write, read = await make_pipe() +async def test_pipe_channel_error_on_close(): + # Make sure we correctly handle a failure from kernel32.CloseHandle + send_channel, receive_channel = await make_pipe_channel() + + assert kernel32.CloseHandle(_handle(receive_channel._handle_holder.handle)) + assert kernel32.CloseHandle(_handle(send_channel._handle_holder.handle)) + + with pytest.raises(OSError): + await send_channel.aclose() + with pytest.raises(OSError): + await 
receive_channel.aclose() + + +async def test_closed_resource_error(): + send_stream, receive_stream = await make_pipe_stream() + + await send_stream.aclose() + with pytest.raises(_core.ClosedResourceError): + await send_stream.send_all(b"Hello") + + send_channel, receive_channel = await make_pipe_channel() + + with pytest.raises(_core.ClosedResourceError): + async with _core.open_nursery() as nursery: + nursery.start_soon(receive_channel.receive) + await wait_all_tasks_blocked(0.01) + await receive_channel.aclose() + await send_channel.aclose() + with pytest.raises(_core.ClosedResourceError): + await send_channel.send(b"Hello") + + +async def test_pipe_streams_combined(): + write, read = await make_pipe_stream() count = 2 ** 20 replicas = 3 @@ -73,13 +138,37 @@ async def reader(): assert total_received == count * replicas - async with _core.open_nursery() as n: - n.start_soon(sender) - n.start_soon(reader) + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(reader) + + +async def test_pipe_channels_combined(): + async def sender(): + async with write: + b = bytearray(count) + for _ in range(replicas): + await write.send(b) + + async def reader(): + async with read: + await wait_all_tasks_blocked() + total_received = 0 + async for b in read: + total_received += len(b) + + assert total_received == count * replicas + + for count in (8, DEFAULT_RECEIVE_SIZE, 2 ** 20): + for replicas in (1, 2, 3): + write, read = await make_pipe_channel() + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(reader) -async def test_async_with(): - w, r = await make_pipe() +async def test_async_with_stream(): + w, r = await make_pipe_stream() async with w, r: pass @@ -89,8 +178,19 @@ async def test_async_with(): await r.receive_some(10) -async def test_close_during_write(): - w, r = await make_pipe() +async def test_async_with_channel(): + w, r = await make_pipe_channel() + async with w, r: + pass + + with pytest.raises(_core.ClosedResourceError): + await w.send(None) + with pytest.raises(_core.ClosedResourceError): + await r.receive() + + +async def test_close_stream_during_write(): + w, r = await make_pipe_stream() async with _core.open_nursery() as nursery: async def write_forever(): @@ -104,7 +204,22 @@ async def write_forever(): await w.aclose() +async def test_close_channel_during_write(): + w, r = await make_pipe_channel() + async with _core.open_nursery() as nursery: + + async def write_forever(): + with pytest.raises(_core.ClosedResourceError) as excinfo: + while True: + await w.send(b"x" * 4096) + assert "another task" in str(excinfo.value) + + nursery.start_soon(write_forever) + await wait_all_tasks_blocked(0.1) + await w.aclose() + + async def test_pipe_fully(): # passing make_clogged_pipe tests wait_send_all_might_not_block, and we # can't implement that on Windows - await check_one_way_stream(make_pipe, None) + await check_one_way_stream(make_pipe_stream, None) diff --git a/trio/tests/test_worker_process.py b/trio/tests/test_worker_process.py new file mode 100644 index 0000000000..e954bf6b13 --- /dev/null +++ b/trio/tests/test_worker_process.py @@ -0,0 +1,277 @@ +import multiprocessing +import os + +import pytest + +from .. import _core, BrokenResourceError +from .._sync import CapacityLimiter +from .._timeouts import fail_after, TooSlowError +from .. 
import _worker_processes +from .._core.tests.tutil import slow +from .._worker_processes import ( + to_process_run_sync, + current_default_process_limiter, + BrokenWorkerError, +) +from ..testing import wait_all_tasks_blocked +from .._threads import to_thread_run_sync + + +@pytest.fixture(autouse=True) +def empty_proc_cache(): + while True: + try: + proc = _worker_processes.PROC_CACHE.pop() + proc.kill() + except IndexError: + return + + +def _echo_and_pid(x): # pragma: no cover + return (x, os.getpid()) + + +def _raise_pid(): # pragma: no cover + raise ValueError(os.getpid()) + + +@slow +async def test_run_in_worker_process(): + trio_pid = os.getpid() + limiter = CapacityLimiter(1) + + x, child_pid = await to_process_run_sync(_echo_and_pid, 1, limiter=limiter) + assert x == 1 + assert child_pid != trio_pid + + with pytest.raises(ValueError) as excinfo: + await to_process_run_sync(_raise_pid, limiter=limiter) + print(excinfo.value.args) + assert excinfo.value.args[0] != trio_pid + + +def _block_proc_on_queue(q, ev, done_ev): # pragma: no cover + # Make the process block for a controlled amount of time + ev.set() + q.get() + done_ev.set() + + +@slow +async def test_run_in_worker_process_cancellation(capfd): + async def child(q, ev, done_ev, cancellable): + print("start") + try: + return await to_process_run_sync( + _block_proc_on_queue, q, ev, done_ev, cancellable=cancellable + ) + finally: + print("exit") + + m = multiprocessing.Manager() + q = m.Queue() + ev = m.Event() + done_ev = m.Event() + + # This one can't be cancelled + async with _core.open_nursery() as nursery: + nursery.start_soon(child, q, ev, done_ev, False) + await to_thread_run_sync(ev.wait, cancellable=True) + nursery.cancel_scope.cancel() + with _core.CancelScope(shield=True): + await wait_all_tasks_blocked(0.01) + # It's still running + assert not done_ev.is_set() + q.put(None) + # Now it exits + + ev = m.Event() + done_ev = m.Event() + # But if we cancel *before* it enters, the entry is itself a cancellation + # point + with _core.CancelScope() as scope: + scope.cancel() + await child(q, ev, done_ev, False) + assert scope.cancelled_caught + capfd.readouterr() + + ev = m.Event() + done_ev = m.Event() + # This is truly cancellable by killing the process + async with _core.open_nursery() as nursery: + nursery.start_soon(child, q, ev, done_ev, True) + # Give it a chance to get started. (This is important because + # to_thread_run_sync does a checkpoint_if_cancelled before + # blocking on the thread, and we don't want to trigger this.) + await wait_all_tasks_blocked() + assert capfd.readouterr().out.rstrip() == "start" + await to_thread_run_sync(ev.wait, cancellable=True) + # Then cancel it. 
+ nursery.cancel_scope.cancel() + # The task exited, but the process died + assert not done_ev.is_set() + assert capfd.readouterr().out.rstrip() == "exit" + + +def _null_func(): # pragma: no cover + pass + + +async def test_run_in_worker_process_fail_to_spawn(monkeypatch): + # Test the unlikely but possible case where trying to spawn a worker fails + def bad_start(): + raise RuntimeError("the engines canna take it captain") + + monkeypatch.setattr(_worker_processes, "WorkerProc", bad_start) + + limiter = current_default_process_limiter() + assert limiter.borrowed_tokens == 0 + + # We get an appropriate error, and the limiter is cleanly released + with pytest.raises(RuntimeError) as excinfo: + await to_process_run_sync(_null_func) # pragma: no cover + assert "engines" in str(excinfo.value) + + assert limiter.borrowed_tokens == 0 + + +async def _null_async_fn(): # pragma: no cover + pass + + +@slow +async def test_trio_to_process_run_sync_expected_error(): + with pytest.raises(TypeError, match="expected a sync function"): + await to_process_run_sync(_null_async_fn) + + +def _segfault_out_of_bounds_pointer(): # pragma: no cover + # https://wiki.python.org/moin/CrashingPython + import ctypes + + i = ctypes.c_char(b"a") + j = ctypes.pointer(i) + c = 0 + while True: + j[c] = b"a" + c += 1 + + +@slow +async def test_to_process_run_sync_raises_on_segfault(): + # This test was flaky on CI across several platforms and implementations. + # I can reproduce it locally if there is some other process using the rest + # of the CPU (F@H in this case) although I cannot explain why running this + # on a busy machine would change the number of iterations (40-50k) needed + # for the OS to notice there is something funny going on with memory access. + # The usual symptom was for the segfault to occur, but the process + # to fail to raise the error for more than one minute, which would + # stall the test runner for 10 minutes. + # Here we raise our own failure error before the test runner timeout (55s) + # but xfail if we actually have to timeout. 
+    try:
+        with fail_after(55):
+            await to_process_run_sync(_segfault_out_of_bounds_pointer, cancellable=True)
+    except BrokenWorkerError:
+        pass
+    except TooSlowError:  # pragma: no cover
+        pytest.xfail("Unable to cause segfault after 55 seconds.")
+    else:  # pragma: no cover
+        pytest.fail("No error was raised on segfault.")
+
+
+def _never_halts(ev):  # pragma: no cover
+    # the important difference from a blocking call is the CPU usage
+    ev.set()
+    while True:
+        pass
+
+
+@slow
+async def test_to_process_run_sync_cancel_infinite_loop():
+    m = multiprocessing.Manager()
+    ev = m.Event()
+
+    async def child():
+        await to_process_run_sync(_never_halts, ev, cancellable=True)
+
+    async with _core.open_nursery() as nursery:
+        nursery.start_soon(child)
+        await to_thread_run_sync(ev.wait, cancellable=True)
+        nursery.cancel_scope.cancel()
+
+
+@slow
+async def test_to_process_run_sync_raises_on_kill():
+    m = multiprocessing.Manager()
+    ev = m.Event()
+
+    async def child():
+        await to_process_run_sync(_never_halts, ev)
+
+    await to_process_run_sync(_null_func)
+    proc = _worker_processes.PROC_CACHE._cache[0]
+    with pytest.raises(BrokenWorkerError):
+        async with _core.open_nursery() as nursery:
+            nursery.start_soon(child)
+            await to_thread_run_sync(ev.wait)
+            proc.kill()
+
+
+@slow
+async def test_spawn_worker_in_thread_and_prune_cache():
+    # make sure we can successfully put worker spawning in a trio thread
+    proc = await to_thread_run_sync(_worker_processes.WorkerProc)
+    # take its pid and kill it for the next test
+    pid1 = proc._proc.pid
+    proc.kill()
+    assert proc.join(1)
+    # put the dead proc into the cache (normal code never does this)
+    _worker_processes.PROC_CACHE.push(proc)
+    # dead procs shouldn't pop out
+    with pytest.raises(IndexError):
+        _worker_processes.PROC_CACHE.pop()
+    _worker_processes.PROC_CACHE.push(proc)
+    # should spawn a new worker and remove the dead one
+    _, pid2 = await to_process_run_sync(_echo_and_pid, None)
+    assert len(_worker_processes.PROC_CACHE) == 1
+    assert pid1 != pid2
+
+
+@slow
+async def test_to_process_run_sync_large_job():
+    n = 2 ** 20
+    x, _ = await to_process_run_sync(_echo_and_pid, bytearray(n))
+    assert len(x) == n
+
+
+async def test_exhaustively_cancel_run_sync():
+    # To test that cancellation never leaves a living process behind, we
+    # currently have to target all but the last checkpoint manually.
+    m = multiprocessing.Manager()
+    ev = m.Event()
+
+    # cancel at job send
+    async def fake_monitor():
+        c.cancel()
+
+    proc = _worker_processes.WorkerProc()
+    proc._child_monitor = fake_monitor
+    with _core.CancelScope() as c:
+        await proc.run_sync(_never_halts, ev)
+    assert proc.join(1)
+
+    # cancel at result recv is tested elsewhere
+
+
+def _shorten_timeout():  # pragma: no cover
+    _worker_processes.IDLE_TIMEOUT = 0
+
+
+@slow
+async def test_racing_timeout():
+    proc = _worker_processes.WorkerProc()
+    await proc.run_sync(_shorten_timeout)
+    assert proc.join(10)
+    with pytest.raises(BrokenResourceError):
+        await proc.run_sync(_null_func)
diff --git a/trio/to_process.py b/trio/to_process.py
new file mode 100644
index 0000000000..e08c6faf22
--- /dev/null
+++ b/trio/to_process.py
@@ -0,0 +1,2 @@
+from ._worker_processes import to_process_run_sync as run_sync
+from ._worker_processes import current_default_process_limiter
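[Editor's note: ``trio.to_process`` re-exports ``current_default_process_limiter``, which returns the ``CapacityLimiter`` that bounds concurrent workers, and ``run_sync`` also accepts a per-call ``limiter``. The sketch below is illustrative only and not part of the diff; the function ``grind`` and the job sizes are made up, while ``CapacityLimiter.total_tokens`` is the existing settable Trio attribute.]

    # Hypothetical sketch (not part of this diff) of limiting worker concurrency.
    import trio


    def grind(n):
        return sum(i * i for i in range(n))


    async def main():
        # Shrink the run-wide default so other work always has a spare CPU.
        default = trio.to_process.current_default_process_limiter()
        default.total_tokens = max(1, default.total_tokens - 1)

        # Or cap just this batch of jobs with a dedicated limiter.
        batch_limiter = trio.CapacityLimiter(2)
        results = {}

        async def job(name, n):
            results[name] = await trio.to_process.run_sync(grind, n, limiter=batch_limiter)

        async with trio.open_nursery() as nursery:
            for name, n in [("small", 10_000), ("medium", 100_000), ("large", 1_000_000)]:
                nursery.start_soon(job, name, n)
        print(results)


    if __name__ == "__main__":
        trio.run(main)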