diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index e69de29..2603c6f 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -0,0 +1 @@
+* @madkinsz @desertaxle @ahuang11
diff --git a/.gitignore b/.gitignore
index b96a3be..8cded71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# dask output
+dask-worker-space
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f47addc..fda999b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,4 +25,4 @@ Released on ????? ?th, 20??.
 
 ### Added
 
-- `task_name` task - [#1](https://github.com/PrefectHQ/prefect-dask/pull/1)
+- Migrated `DaskTaskRunner` - [#2](https://github.com/PrefectHQ/prefect-dask/pull/2)
diff --git a/README.md b/README.md
index f450bde..de6a4ae 100644
--- a/README.md
+++ b/README.md
@@ -25,19 +25,25 @@ pip install prefect-dask
 
 ### Write and run a flow
 
 ```python
-from prefect import flow
-from prefect_dask.tasks import (
-    goodbye_prefect_dask,
-    hello_prefect_dask,
-)
+from prefect import flow, task
+from prefect_dask.task_runners import DaskTaskRunner
 
+@task
+def say_hello(name):
+    print(f"hello {name}")
 
-@flow
-def example_flow():
-    hello_prefect_dask
-    goodbye_prefect_dask
+@task
+def say_goodbye(name):
+    print(f"goodbye {name}")
 
-example_flow()
+@flow(task_runner=DaskTaskRunner())
+def greetings(names):
+    for name in names:
+        say_hello(name)
+        say_goodbye(name)
+
+if __name__ == "__main__":
+    greetings(["arthur", "trillian", "ford", "marvin"])
 ```
 
 ## Resources
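The README example above runs against a temporary local cluster. For readers who already have a Dask cluster, a minimal sketch of pointing the same flow at it via the `address` kwarg; the scheduler address below is a placeholder, not something this PR ships:

```python
from prefect import flow, task
from prefect_dask.task_runners import DaskTaskRunner

@task
def say_hello(name):
    print(f"hello {name}")

# Placeholder address; substitute your own scheduler's address.
@flow(task_runner=DaskTaskRunner(address="tcp://127.0.0.1:8786"))
def greetings(names):
    for name in names:
        say_hello(name)

if __name__ == "__main__":
    greetings(["arthur", "trillian"])
```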
diff --git a/docs/task_runners.md b/docs/task_runners.md
new file mode 100644
index 0000000..a44690e
--- /dev/null
+++ b/docs/task_runners.md
@@ -0,0 +1 @@
+::: prefect_dask.task_runners
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 8b53ccf..d8f1ae6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -36,3 +36,4 @@ plugins:
 
 nav:
     - Home: index.md
+    - Task Runners: task_runners.md
diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py
index 4d52a61..bfd452d 100644
--- a/prefect_dask/__init__.py
+++ b/prefect_dask/__init__.py
@@ -1,3 +1,4 @@
 from . import _version
+from .task_runners import DaskTaskRunner  # noqa
 
 __version__ = _version.get_versions()["version"]
diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py
new file mode 100644
index 0000000..d424fee
--- /dev/null
+++ b/prefect_dask/task_runners.py
@@ -0,0 +1,279 @@
+"""
+Interface and implementations of the Dask Task Runner.
+[Task Runners](https://orion-docs.prefect.io/api-ref/prefect/task-runners/)
+in Prefect are responsible for managing the execution of Prefect task runs.
+Generally speaking, users are not expected to interact with
+task runners outside of configuring and initializing them for a flow.
+Example:
+    >>> from prefect import flow, task
+    >>> from prefect.task_runners import SequentialTaskRunner
+    >>> from typing import List
+    >>>
+    >>> @task
+    >>> def say_hello(name):
+    ...     print(f"hello {name}")
+    >>>
+    >>> @task
+    >>> def say_goodbye(name):
+    ...     print(f"goodbye {name}")
+    >>>
+    >>> @flow(task_runner=SequentialTaskRunner())
+    >>> def greetings(names: List[str]):
+    ...     for name in names:
+    ...         say_hello(name)
+    ...         say_goodbye(name)
+    Switching to a `DaskTaskRunner`:
+    >>> from prefect_dask.task_runners import DaskTaskRunner
+    >>> greetings.task_runner = DaskTaskRunner()
+    >>> greetings(["arthur", "trillian", "ford", "marvin"])
+    hello arthur
+    goodbye arthur
+    hello trillian
+    hello ford
+    goodbye marvin
+    hello marvin
+    goodbye ford
+    goodbye trillian
+"""
+
+from contextlib import AsyncExitStack
+from typing import Any, Awaitable, Callable, Dict, Optional, Union
+from uuid import UUID
+
+import distributed
+from prefect.futures import PrefectFuture
+from prefect.orion.schemas.core import TaskRun
+from prefect.orion.schemas.states import State
+from prefect.states import exception_to_crashed_state
+from prefect.task_runners import BaseTaskRunner, R, TaskConcurrencyType
+from prefect.utilities.asyncio import A
+from prefect.utilities.collections import visit_collection
+from prefect.utilities.hashing import to_qualified_name
+from prefect.utilities.importtools import import_object
+
+
+class DaskTaskRunner(BaseTaskRunner):
+    """
+    A parallel task runner that submits tasks to the `dask.distributed` scheduler.
+    By default, a temporary `distributed.LocalCluster` is created (and
+    subsequently torn down) within the `start()` contextmanager. To use a
+    different cluster class (e.g.
+    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
+    specify `cluster_class`/`cluster_kwargs`.
+    Alternatively, if you already have a dask cluster running, you can provide
+    the address of the scheduler via the `address` kwarg.
+    !!! warning "Multiprocessing safety"
+        Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows
+        in scripts must be guarded with `if __name__ == "__main__":` or warnings will
+        be displayed.
+    Args:
+        address (string, optional): Address of a currently running dask
+            scheduler; if one is not provided, a temporary cluster will be
+            created in `DaskTaskRunner.start()`. Defaults to `None`.
+        cluster_class (string or callable, optional): The cluster class to use
+            when creating a temporary dask cluster. Can be either the full
+            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
+        cluster_kwargs (dict, optional): Additional kwargs to pass to the
+            `cluster_class` when creating a temporary dask cluster.
+        adapt_kwargs (dict, optional): Additional kwargs to pass to `cluster.adapt`
+            when creating a temporary dask cluster. Note that adaptive scaling
+            is only enabled if `adapt_kwargs` are provided.
+        client_kwargs (dict, optional): Additional kwargs to use when creating a
+            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
+    Examples:
+        Using a temporary local dask cluster:
+        >>> from prefect import flow
+        >>> from prefect_dask.task_runners import DaskTaskRunner
+        >>> @flow(task_runner=DaskTaskRunner())
+        >>> def my_flow():
+        >>>     ...
+        Using a temporary cluster running elsewhere.
+        Any Dask cluster class should work; here we use
+        [dask-cloudprovider](https://cloudprovider.dask.org):
+        >>> DaskTaskRunner(
+        >>>     cluster_class="dask_cloudprovider.FargateCluster",
+        >>>     cluster_kwargs={
+        >>>         "image": "prefecthq/prefect:latest",
+        >>>         "n_workers": 5,
+        >>>     },
+        >>> )
+        Connecting to an existing dask cluster:
+        >>> DaskTaskRunner(address="192.0.2.255:8786")
+    """
+
+    def __init__(
+        self,
+        address: str = None,
+        cluster_class: Union[str, Callable] = None,
+        cluster_kwargs: dict = None,
+        adapt_kwargs: dict = None,
+        client_kwargs: dict = None,
+    ):
+        # Validate settings and infer defaults
+        if address:
+            if cluster_class or cluster_kwargs or adapt_kwargs:
+                raise ValueError(
+                    "Cannot specify `address` and "
+                    "`cluster_class`/`cluster_kwargs`/`adapt_kwargs`"
+                )
+        else:
+            if isinstance(cluster_class, str):
+                cluster_class = import_object(cluster_class)
+
+        # Create copies of incoming kwargs since we may mutate them
+        cluster_kwargs = cluster_kwargs.copy() if cluster_kwargs else {}
+        adapt_kwargs = adapt_kwargs.copy() if adapt_kwargs else {}
+        client_kwargs = client_kwargs.copy() if client_kwargs else {}
+
+        # Update kwargs defaults
+        client_kwargs.setdefault("set_as_default", False)
+
+        # The user cannot specify async/sync themselves
+        if "asynchronous" in client_kwargs:
+            raise ValueError(
+                "`client_kwargs` cannot set `asynchronous`. "
+                "This option is managed by Prefect."
+            )
+        if "asynchronous" in cluster_kwargs:
+            raise ValueError(
+                "`cluster_kwargs` cannot set `asynchronous`. "
+                "This option is managed by Prefect."
+            )
+
+        # Store settings
+        self.address = address
+        self.cluster_class = cluster_class
+        self.cluster_kwargs = cluster_kwargs
+        self.adapt_kwargs = adapt_kwargs
+        self.client_kwargs = client_kwargs
+
+        # Runtime attributes
+        self._client: "distributed.Client" = None
+        self._cluster: "distributed.deploy.Cluster" = None
+        self._dask_futures: Dict[UUID, "distributed.Future"] = {}
+
+        super().__init__()
+
+    @property
+    def concurrency_type(self) -> TaskConcurrencyType:
+        return (
+            TaskConcurrencyType.PARALLEL
+            if self.cluster_kwargs.get("processes")
+            else TaskConcurrencyType.CONCURRENT
+        )
+
+    async def submit(
+        self,
+        task_run: TaskRun,
+        run_fn: Callable[..., Awaitable[State[R]]],
+        run_kwargs: Dict[str, Any],
+        asynchronous: A = True,
+    ) -> PrefectFuture[R, A]:
+        if not self._started:
+            raise RuntimeError(
+                "The task runner must be started before submitting work."
+            )
+
+        # Cast Prefect futures to Dask futures where possible to optimize Dask task
+        # scheduling
+        run_kwargs = await self._optimize_futures(run_kwargs)
+
+        self._dask_futures[task_run.id] = self._client.submit(
+            run_fn,
+            # Dask displays the text up to the first '-' as the name; include the
+            # task run id to ensure the key is unique.
+            key=f"{task_run.name}-{task_run.id.hex}",
+            # Dask defaults to treating functions as pure, but we set this here for
+            # explicit expectations. If this task run is submitted to Dask twice, the
+            # result of the first run should be returned. Subsequent runs would return
+            # `Abort` exceptions if they were submitted again.
+            pure=True,
+            **run_kwargs,
+        )
+
+        return PrefectFuture(
+            task_run=task_run, task_runner=self, asynchronous=asynchronous
+        )
+
+    def _get_dask_future(self, prefect_future: PrefectFuture) -> "distributed.Future":
+        """
+        Retrieve the dask future corresponding to a Prefect future.
+        The Dask future is for the `run_fn`, which should return a `State`.
+ """ + return self._dask_futures[prefect_future.run_id] + + async def _optimize_futures(self, expr): + async def visit_fn(expr): + if isinstance(expr, PrefectFuture): + dask_future = self._dask_futures.get(expr.run_id) + if dask_future is not None: + return dask_future + # Fallback to return the expression unaltered + return expr + + return await visit_collection(expr, visit_fn=visit_fn, return_data=True) + + async def wait( + self, + prefect_future: PrefectFuture, + timeout: float = None, + ) -> Optional[State]: + future = self._get_dask_future(prefect_future) + try: + return await future.result(timeout=timeout) + except distributed.TimeoutError: + return None + except BaseException as exc: + return exception_to_crashed_state(exc) + + async def _start(self, exit_stack: AsyncExitStack): + """ + Start the task runner and prep for context exit. + - Creates a cluster if an external address is not set. + - Creates a client to connect to the cluster. + - Pushes a call to wait for all running futures to complete on exit. + """ + if self.address: + self.logger.info( + f"Connecting to an existing Dask cluster at {self.address}" + ) + connect_to = self.address + else: + self.cluster_class = self.cluster_class or distributed.LocalCluster + + self.logger.info( + f"Creating a new Dask cluster with " + f"`{to_qualified_name(self.cluster_class)}`" + ) + connect_to = self._cluster = await exit_stack.enter_async_context( + self.cluster_class(asynchronous=True, **self.cluster_kwargs) + ) + if self.adapt_kwargs: + self._cluster.adapt(**self.adapt_kwargs) + + self._client = await exit_stack.enter_async_context( + distributed.Client(connect_to, asynchronous=True, **self.client_kwargs) + ) + + if self._client.dashboard_link: + self.logger.info( + f"The Dask dashboard is available at {self._client.dashboard_link}", + ) + + def __getstate__(self): + """ + Allow the `DaskTaskRunner` to be serialized by dropping + the `distributed.Client`, which contains locks. + Must be deserialized on a dask worker. + """ + data = self.__dict__.copy() + data.update({k: None for k in {"_client", "_cluster"}}) + return data + + def __setstate__(self, data: dict): + """ + Restore the `distributed.Client` by loading the client on a dask worker. + """ + self.__dict__.update(data) + self._client = distributed.get_client() diff --git a/requirements-dev.txt b/requirements-dev.txt index ead4728..0b1f87c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ pytest black flake8 +flaky mypy mkdocs mkdocs-material @@ -11,4 +12,4 @@ pytest-asyncio mock; python_version < '3.8' mkdocs-gen-files interrogate -coverage \ No newline at end of file +coverage diff --git a/requirements.txt b/requirements.txt index 2bd9d1d..f4ec152 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ -prefect>=2.0a13 \ No newline at end of file +prefect @ git+https://github.com/PrefectHQ/prefect@orion +dask==2022.2.0; python_version < '3.8' +dask>=2022.5.0; python_version >= '3.8' diff --git a/setup.cfg b/setup.cfg index 36be797..789cfb6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,8 @@ parentdir_prefix = [tool:interrogate] ignore-init-module = True exclude = prefect_dask/_version.py, tests, setup.py, versioneer.py, docs, site +ignore_init_method = True +ignore_regex = submit,wait,concurrency_type,_optimize_futures fail-under = 95 omit-covered-files = True @@ -36,3 +38,7 @@ show_missing = True [tool:pytest] asyncio_mode = auto + +markers = + service(arg): a service integration test. 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ead4728..0b1f87c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,6 +1,7 @@
 pytest
 black
 flake8
+flaky
 mypy
 mkdocs
 mkdocs-material
@@ -11,4 +12,4 @@ pytest-asyncio
 mock; python_version < '3.8'
 mkdocs-gen-files
 interrogate
-coverage
\ No newline at end of file
+coverage
diff --git a/requirements.txt b/requirements.txt
index 2bd9d1d..f4ec152 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,3 @@
-prefect>=2.0a13
\ No newline at end of file
+prefect @ git+https://github.com/PrefectHQ/prefect@orion
+dask==2022.2.0; python_version < '3.8'
+dask>=2022.5.0; python_version >= '3.8'
diff --git a/setup.cfg b/setup.cfg
index 36be797..789cfb6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,6 +24,8 @@ parentdir_prefix =
 
 [tool:interrogate]
 ignore-init-module = True
 exclude = prefect_dask/_version.py, tests, setup.py, versioneer.py, docs, site
+ignore-init-method = True
+ignore-regex = submit,wait,concurrency_type,_optimize_futures
 fail-under = 95
 omit-covered-files = True
@@ -36,3 +38,7 @@ show_missing = True
 
 [tool:pytest]
 asyncio_mode = auto
+
+markers =
+    service(arg): a service integration test. For example 'docker'
+    enable_orion_handler: by default, sending logs to the API is disabled. Tests marked with this use the handler.
\ No newline at end of file
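The two pytest markers registered above are applied like any other marker. A hypothetical illustration (these test bodies are not part of the PR):

```python
import pytest

@pytest.mark.service("docker")
def test_against_docker_service():
    ...  # runs only when service integration tests are enabled

@pytest.mark.enable_orion_handler
def test_that_sends_logs_to_the_api():
    ...  # opts this test into sending logs to the API
```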
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_task_runners.py b/tests/test_task_runners.py
new file mode 100644
index 0000000..25d9250
--- /dev/null
+++ b/tests/test_task_runners.py
@@ -0,0 +1,137 @@
+import asyncio
+import logging
+import sys
+from uuid import uuid4
+
+import cloudpickle
+import distributed
+import pytest
+from prefect.orion.schemas.core import TaskRun
+from prefect.states import State
+from prefect.task_runners import TaskConcurrencyType
+from prefect.testing.fixtures import hosted_orion_api, use_hosted_orion  # noqa: F401
+from prefect.testing.standard_test_suites import TaskRunnerStandardTestSuite
+
+from prefect_dask import DaskTaskRunner
+
+
+@pytest.fixture(scope="session")
+def event_loop(request):
+    """
+    Redefine the event loop to support session/module-scoped fixtures;
+    see https://github.com/pytest-dev/pytest-asyncio/issues/68
+    When running on Windows, we need to use a non-default loop for subprocess support.
+    """
+    if sys.platform == "win32" and sys.version_info >= (3, 8):
+        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
+
+    policy = asyncio.get_event_loop_policy()
+
+    if sys.version_info < (3, 8) and sys.platform != "win32":
+        from prefect.utilities.compat import ThreadedChildWatcher
+
+        # Python < 3.8 does not use a `ThreadedChildWatcher` by default, which can
+        # lead to errors in tests, as the previous default `SafeChildWatcher` is not
+        # compatible with threaded event loops.
+        policy.set_child_watcher(ThreadedChildWatcher())
+
+    loop = policy.new_event_loop()
+
+    # Configure asyncio logging to capture long-running tasks
+    asyncio_logger = logging.getLogger("asyncio")
+    asyncio_logger.setLevel("WARNING")
+    asyncio_logger.addHandler(logging.StreamHandler())
+    loop.set_debug(True)
+    loop.slow_callback_duration = 0.25
+
+    try:
+        yield loop
+    finally:
+        loop.close()
+
+    # Workaround for failures in pytest_asyncio 0.17;
+    # see https://github.com/pytest-dev/pytest-asyncio/issues/257
+    policy.set_event_loop(loop)
+
+
+@pytest.fixture
+def dask_task_runner_with_existing_cluster(use_hosted_orion):  # noqa
+    """
+    Generate a dask task runner that's connected to a local cluster.
+    """
+    with distributed.LocalCluster(n_workers=2) as cluster:
+        with distributed.Client(cluster) as client:
+            address = client.scheduler.address
+            yield DaskTaskRunner(address=address)
+
+
+@pytest.fixture
+def dask_task_runner_with_process_pool():
+    yield DaskTaskRunner(cluster_kwargs={"processes": True})
+
+
+@pytest.fixture
+def dask_task_runner_with_thread_pool():
+    yield DaskTaskRunner(cluster_kwargs={"processes": False})
+
+
+@pytest.fixture
+def default_dask_task_runner():
+    yield DaskTaskRunner()
+
+
+class TestDaskTaskRunner(TaskRunnerStandardTestSuite):
+    @pytest.fixture(
+        params=[
+            default_dask_task_runner,
+            dask_task_runner_with_existing_cluster,
+            dask_task_runner_with_process_pool,
+            dask_task_runner_with_thread_pool,
+        ]
+    )
+    def task_runner(self, request):
+        yield request.getfixturevalue(
+            request.param._pytestfixturefunction.name or request.param.__name__
+        )
+
+    async def test_is_pickleable_after_start(self, task_runner):
+        """
+        The task_runner must be picklable, as it is attached to `PrefectFuture`
+        objects. Reimplemented to set the Dask client as default to allow unpickling.
+        """
+        task_runner.client_kwargs["set_as_default"] = True
+        async with task_runner.start():
+            pickled = cloudpickle.dumps(task_runner)
+            unpickled = cloudpickle.loads(pickled)
+            assert isinstance(unpickled, type(task_runner))
+
+    @pytest.mark.parametrize("exception", [KeyboardInterrupt(), ValueError("test")])
+    async def test_wait_captures_exceptions_as_crashed_state(
+        self, task_runner, exception
+    ):
+        """
+        Dask wraps the exception; interrupts will result in "Cancelled" tasks
+        or "Killed" workers, while normal errors will result in the raw error
+        with Dask. We care more about the crash detection and the lack of a
+        re-raise here than about the equality of the exception.
+        """
+        if task_runner.concurrency_type != TaskConcurrencyType.PARALLEL:
+            pytest.skip(
+                f"This will abort the run for "
+                f"{task_runner.concurrency_type} task runners."
+            )
+
+        task_run = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="bar")
+
+        async def fake_orchestrate_task_run():
+            raise exception
+
+        async with task_runner.start():
+            future = await task_runner.submit(
+                task_run=task_run, run_fn=fake_orchestrate_task_run, run_kwargs={}
+            )
+
+            state = await task_runner.wait(future, 5)
+            assert state is not None, "wait timed out"
+            assert isinstance(state, State), "wait should return a state"
+            assert state.name == "Crashed"
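To make the crash-capture test concrete, a hedged sketch (not part of the PR) of the user-facing behavior it guards: an exception raised inside a task on a Dask worker is captured as a state rather than re-raised into the calling process. In the Prefect 2 betas this PR targets, calling a flow returns its final state:

```python
from prefect import flow, task
from prefect_dask import DaskTaskRunner

@task
def boom():
    raise ValueError("test")

# processes=True exercises the PARALLEL concurrency path, like the tests above.
@flow(task_runner=DaskTaskRunner(cluster_kwargs={"processes": True}))
def fragile_flow():
    boom()

if __name__ == "__main__":  # guard is required when Dask spawns worker processes
    state = fragile_flow()  # final state reflects the task failure
    print(state)
```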