This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Migrate from Orion #2

Merged 19 commits on Jun 7, 2022

Changes from 3 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# dask output
dask-worker-space

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
26 changes: 16 additions & 10 deletions README.md
@@ -25,19 +25,25 @@ pip install prefect-dask
### Write and run a flow

```python
from prefect import flow, task
from prefect_dask.task_runners import DaskTaskRunner

@task
def say_hello(name):
    print(f"hello {name}")

@task
def say_goodbye(name):
    print(f"goodbye {name}")

@flow(task_runner=DaskTaskRunner())
def greetings(names):
    for name in names:
        say_hello(name)
        say_goodbye(name)

if __name__ == "__main__":
    greetings(["arthur", "trillian", "ford", "marvin"])
```
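
The task runner can also attach to a Dask cluster that is already running instead of creating a temporary local one; a minimal sketch (the scheduler address below is a placeholder):

```python
from prefect import flow
from prefect_dask.task_runners import DaskTaskRunner

# Connect to an existing Dask scheduler instead of creating a
# temporary `distributed.LocalCluster`.
@flow(task_runner=DaskTaskRunner(address="tcp://192.0.2.255:8786"))
def my_flow():
    ...
```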

## Resources
1 change: 1 addition & 0 deletions docs/task_runners.md
@@ -0,0 +1 @@
::: prefect_dask.task_runners
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -36,3 +36,4 @@ plugins:

nav:
- Home: index.md
- Task Runners: task_runners.md
1 change: 1 addition & 0 deletions prefect_dask/__init__.py
@@ -1,3 +1,4 @@
from . import _version
from .task_runners import DaskTaskRunner # noqa

__version__ = _version.get_versions()["version"]
319 changes: 319 additions & 0 deletions prefect_dask/task_runners.py
@@ -0,0 +1,319 @@
"""
Interface and implementations of the Dask Task Runner.
[Task Runners](/concepts/task-runners/) in Prefect are
responsible for managing the execution of Prefect task runs.
Generally speaking, users are not expected to interact with
task runners outside of configuring and initializing them for a flow.
Example:
>>> from prefect import flow, task
>>> from prefect.task_runners import SequentialTaskRunner
>>> from typing import List
>>>
>>> @task
>>> def say_hello(name):
... print(f"hello {name}")
>>>
>>> @task
>>> def say_goodbye(name):
... print(f"goodbye {name}")
>>>
>>> @flow(task_runner=SequentialTaskRunner())
>>> def greetings(names: List[str]):
... for name in names:
... say_hello(name)
... say_goodbye(name)
Switching to a `DaskTaskRunner`:
>>> from prefect_dask.task_runners import DaskTaskRunner
>>> greetings.task_runner = DaskTaskRunner()
>>> greetings(["arthur", "trillian", "ford", "marvin"])
hello arthur
goodbye arthur
hello trillian
hello ford
goodbye marvin
hello marvin
goodbye ford
goodbye trillian
"""

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union
from uuid import UUID

if TYPE_CHECKING:
    import distributed
else:
    # `distributed` is imported lazily via the `_distributed` property so the
    # task runner can be configured without the extra installed.
    distributed = None
from prefect.futures import PrefectFuture
from prefect.orion.schemas.core import TaskRun
from prefect.orion.schemas.states import State
from prefect.states import exception_to_crashed_state
from prefect.task_runners import BaseTaskRunner, R, TaskConcurrencyType
from prefect.utilities.asyncio import A
from prefect.utilities.collections import visit_collection
from prefect.utilities.hashing import to_qualified_name
from prefect.utilities.importtools import import_object


class DaskTaskRunner(BaseTaskRunner):
"""
    A parallel task runner that submits tasks to the `dask.distributed`
    scheduler. By default, a temporary `distributed.LocalCluster` is created
    (and subsequently torn down) within the `start()` context manager. To use a
different cluster class (e.g.
[`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
specify `cluster_class`/`cluster_kwargs`.
Alternatively, if you already have a dask cluster running, you can provide
the address of the scheduler via the `address` kwarg.
!!! warning "Multiprocessing safety"
Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows
in scripts must be guarded with `if __name__ == "__main__":` or warnings will
be displayed.
Args:
address (string, optional): Address of a currently running dask
scheduler; if one is not provided, a temporary cluster will be
created in `DaskTaskRunner.start()`. Defaults to `None`.
cluster_class (string or callable, optional): The cluster class to use
when creating a temporary dask cluster. Can be either the full
class name (e.g. `"distributed.LocalCluster"`), or the class itself.
cluster_kwargs (dict, optional): Additional kwargs to pass to the
`cluster_class` when creating a temporary dask cluster.
adapt_kwargs (dict, optional): Additional kwargs to pass to `cluster.adapt`
when creating a temporary dask cluster. Note that adaptive scaling
is only enabled if `adapt_kwargs` are provided.
client_kwargs (dict, optional): Additional kwargs to use when creating a
[`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
Examples:
Using a temporary local dask cluster:
>>> from prefect import flow
    >>> from prefect_dask.task_runners import DaskTaskRunner
    >>> @flow(task_runner=DaskTaskRunner())
>>> def my_flow():
>>> ...
Using a temporary cluster running elsewhere. Any Dask cluster class should
work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):
>>> DaskTaskRunner(
>>> cluster_class="dask_cloudprovider.FargateCluster",
>>> cluster_kwargs={
>>> "image": "prefecthq/prefect:latest",
>>> "n_workers": 5,
>>> },
>>> )
Connecting to an existing dask cluster:
>>> DaskTaskRunner(address="192.0.2.255:8786")
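    Using adaptive scaling on a temporary cluster (a minimal sketch; adaptive
    scaling is only enabled when `adapt_kwargs` is provided, and `minimum` and
    `maximum` are the worker-count bounds forwarded to `cluster.adapt`):
    >>> DaskTaskRunner(
    >>>     adapt_kwargs={"minimum": 1, "maximum": 10},
    >>> )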
"""

def __init__(
self,
        address: Optional[str] = None,
        cluster_class: Optional[Union[str, Callable]] = None,
        cluster_kwargs: Optional[dict] = None,
        adapt_kwargs: Optional[dict] = None,
        client_kwargs: Optional[dict] = None,
):
"""
Initializes keywords.
"""

# Validate settings and infer defaults
if address:
if cluster_class or cluster_kwargs or adapt_kwargs:
raise ValueError(
"Cannot specify `address` and "
"`cluster_class`/`cluster_kwargs`/`adapt_kwargs`"
)
else:
if isinstance(cluster_class, str):
cluster_class = import_object(cluster_class)

        # Create copies of incoming kwargs since we may mutate them
cluster_kwargs = cluster_kwargs.copy() if cluster_kwargs else {}
adapt_kwargs = adapt_kwargs.copy() if adapt_kwargs else {}
client_kwargs = client_kwargs.copy() if client_kwargs else {}

# Update kwargs defaults
client_kwargs.setdefault("set_as_default", False)

# The user cannot specify async/sync themselves
if "asynchronous" in client_kwargs:
raise ValueError(
"`client_kwargs` cannot set `asynchronous`. "
"This option is managed by Prefect."
)
if "asynchronous" in cluster_kwargs:
raise ValueError(
"`cluster_kwargs` cannot set `asynchronous`. "
"This option is managed by Prefect."
)

# Store settings
self.address = address
self.cluster_class = cluster_class
self.cluster_kwargs = cluster_kwargs
self.adapt_kwargs = adapt_kwargs
self.client_kwargs = client_kwargs

# Runtime attributes
self._client: "distributed.Client" = None
self._cluster: "distributed.deploy.Cluster" = None
self._dask_futures: Dict[UUID, "distributed.Future"] = {}

super().__init__()

@property
def concurrency_type(self) -> TaskConcurrencyType:
"""
Set concurrency type.
"""
return (
TaskConcurrencyType.PARALLEL
if self.cluster_kwargs.get("processes")
else TaskConcurrencyType.CONCURRENT
)

async def submit(
self,
task_run: TaskRun,
run_fn: Callable[..., Awaitable[State[R]]],
run_kwargs: Dict[str, Any],
asynchronous: A = True,
) -> PrefectFuture[R, A]:
"""
Submit task run.
"""
if not self._started:
raise RuntimeError(
"The task runner must be started before submitting work."
)

# Cast Prefect futures to Dask futures where possible to optimize Dask task
# scheduling
run_kwargs = await self._optimize_futures(run_kwargs)

self._dask_futures[task_run.id] = self._client.submit(
run_fn,
            # Dask displays the text up to the first '-' as the name; include the
            # task run id to ensure the key is unique.
key=f"{task_run.name}-{task_run.id.hex}",
            # Dask defaults to treating functions as pure, but we set this here for
            # explicit expectations. If this task run is submitted to Dask twice, the
            # result of the first run should be returned. Subsequent runs would return
            # `Abort` exceptions if they were submitted again.
pure=True,
**run_kwargs,
)

return PrefectFuture(
task_run=task_run, task_runner=self, asynchronous=asynchronous
)

def _get_dask_future(self, prefect_future: PrefectFuture) -> "distributed.Future":
"""
Retrieve the dask future corresponding to a Prefect future.
The Dask future is for the `run_fn`, which should return a `State`.
"""
return self._dask_futures[prefect_future.run_id]

async def _optimize_futures(self, expr):
"""
Optimizes future.
"""

async def visit_fn(expr):
"""
Visits the fn.
"""
if isinstance(expr, PrefectFuture):
dask_future = self._dask_futures.get(expr.run_id)
if dask_future is not None:
return dask_future
# Fallback to return the expression unaltered
return expr

return await visit_collection(expr, visit_fn=visit_fn, return_data=True)

async def wait(
self,
prefect_future: PrefectFuture,
        timeout: Optional[float] = None,
    ) -> Optional[State]:
        """
        Wait up to `timeout` seconds for the task run to finish, returning its
        final state, or `None` if the timeout is reached.
        """
future = self._get_dask_future(prefect_future)
try:
return await future.result(timeout=timeout)
except self._distributed.TimeoutError:
return None
except BaseException as exc:
return exception_to_crashed_state(exc)

@property
def _distributed(self) -> "distributed":
"""
Delayed import of `distributed` allowing configuration of the task runner
without the extra installed and improves `prefect` import times.
"""
global distributed

if distributed is None:
try:
import distributed
except ImportError as exc:
raise RuntimeError(
"Using the `DaskTaskRunner` requires `distributed` to be installed."
) from exc

return distributed

async def _start(self, exit_stack: AsyncExitStack):
"""
Start the task runner and prep for context exit.
- Creates a cluster if an external address is not set.
- Creates a client to connect to the cluster.
- Pushes a call to wait for all running futures to complete on exit.
"""
if self.address:
self.logger.info(
f"Connecting to an existing Dask cluster at {self.address}"
)
connect_to = self.address
else:
self.cluster_class = self.cluster_class or self._distributed.LocalCluster

self.logger.info(
f"Creating a new Dask cluster with "
f"`{to_qualified_name(self.cluster_class)}`"
)
connect_to = self._cluster = await exit_stack.enter_async_context(
self.cluster_class(asynchronous=True, **self.cluster_kwargs)
)
if self.adapt_kwargs:
self._cluster.adapt(**self.adapt_kwargs)

self._client = await exit_stack.enter_async_context(
self._distributed.Client(
connect_to, asynchronous=True, **self.client_kwargs
)
)

if self._client.dashboard_link:
self.logger.info(
f"The Dask dashboard is available at {self._client.dashboard_link}",
)

def __getstate__(self):
"""
Allow the `DaskTaskRunner` to be serialized by dropping
the `distributed.Client`, which contains locks.
Must be deserialized on a dask worker.
"""
data = self.__dict__.copy()
data.update({k: None for k in {"_client", "_cluster"}})
return data

def __setstate__(self, data: dict):
"""
Restore the `distributed.Client` by loading the client on a dask worker.
"""
self.__dict__.update(data)
self._client = self._distributed.get_client()
4 changes: 3 additions & 1 deletion requirements.txt
@@ -1 +1,3 @@
prefect>=2.0a13
dask==2022.2.0; python_version < '3.8'
dask>=2022.5.0; python_version >= '3.8'
4 changes: 4 additions & 0 deletions setup.cfg
@@ -36,3 +36,7 @@ show_missing = True

[tool:pytest]
asyncio_mode = auto

markers =
service(arg): a service integration test. For example 'docker'
enable_orion_handler: by default, sending logs to the API is disabled. Tests marked with this use the handler.
Empty file added tests/__init__.py
Empty file.