Skip to content

Commit de1a4cc

Browse files
authored
Add support for Kueue scheduling options to RayJob runner (#24)
* feat(rayjob): Support Kueue scheduling options in RayJob runner * chore: Separate out `util` module into package
1 parent 214cbce commit de1a4cc

File tree

11 files changed

+155
-126
lines changed

11 files changed

+155
-126
lines changed

src/jobs/job.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from jobs.assembler.renderers import RENDERERS
1717
from jobs.image import Image
1818
from jobs.types import AnyPath, K8sResourceKind
19-
from jobs.util import remove_none_values, run_command, to_rational
19+
from jobs.utils.helpers import remove_none_values
20+
from jobs.utils.math import to_rational
21+
from jobs.utils.processes import run_command
2022

2123

2224
class BuildMode(enum.Enum):

src/jobs/runner/docker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jobs import Image, Job
77
from jobs.job import DockerResourceOptions
88
from jobs.runner.base import Runner, _make_executor_command
9-
from jobs.util import remove_none_values
9+
from jobs.utils.helpers import remove_none_values
1010

1111

1212
class DockerRunner(Runner):

src/jobs/runner/kueue.py

+5-55
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
from jobs import Image, Job
66
from jobs.runner.base import Runner, _make_executor_command
77
from jobs.types import K8sResourceKind
8-
from jobs.util import (
9-
KubernetesNamespaceMixin,
10-
remove_none_values,
11-
sanitize_rfc1123_domain_name,
12-
)
8+
from jobs.utils.kubernetes import KubernetesNamespaceMixin, sanitize_rfc1123_domain_name
9+
from jobs.utils.kueue import kueue_scheduling_labels
1310

1411

1512
class KueueRunner(Runner, KubernetesNamespaceMixin):
@@ -19,61 +16,14 @@ def __init__(self, **kwargs: str) -> None:
1916
self._queue = kwargs.get("local_queue", "user-queue")
2017

2118
def _make_job_crd(self, job: Job, image: Image, namespace: str) -> client.V1Job:
22-
def _assert_kueue_localqueue(name: str) -> bool:
23-
try:
24-
_ = client.CustomObjectsApi().get_namespaced_custom_object(
25-
"kueue.x-k8s.io",
26-
"v1beta1",
27-
namespace,
28-
"localqueues",
29-
name,
30-
)
31-
return True
32-
except client.exceptions.ApiException:
33-
return False
34-
35-
def _assert_kueue_workloadpriorityclass(name: str) -> bool:
36-
try:
37-
_ = client.CustomObjectsApi().get_cluster_custom_object(
38-
"kueue.x-k8s.io",
39-
"v1beta1",
40-
"workloadpriorityclasses",
41-
name,
42-
)
43-
return True
44-
except client.exceptions.ApiException:
45-
return False
46-
4719
if not job.options:
4820
raise ValueError("Job options must be specified")
4921

50-
sched_opts = job.options.scheduling
51-
if sched_opts:
52-
if queue := sched_opts.queue_name:
53-
if not _assert_kueue_localqueue(queue):
54-
raise ValueError(
55-
f"Specified Kueue local queue does not exist: {queue!r}"
56-
)
57-
if pc := sched_opts.priority_class:
58-
if not _assert_kueue_workloadpriorityclass(pc):
59-
raise ValueError(
60-
f"Specified Kueue workload priority class does not exist: {pc!r}"
61-
)
22+
scheduling_labels = kueue_scheduling_labels(job, self.namespace)
6223

6324
metadata = client.V1ObjectMeta(
6425
generate_name=sanitize_rfc1123_domain_name(job.name),
65-
labels=remove_none_values(
66-
{
67-
"kueue.x-k8s.io/queue-name": (
68-
sched_opts.queue_name
69-
if sched_opts and sched_opts.queue_name
70-
else None
71-
),
72-
"kueue.x-k8s.io/priority-class": (
73-
sched_opts.priority_class if sched_opts else None
74-
),
75-
}
76-
),
26+
labels=scheduling_labels,
7727
)
7828

7929
# Job container
@@ -87,7 +37,7 @@ def _assert_kueue_workloadpriorityclass(name: str) -> bool:
8737
"requests": res.to_kubernetes(kind=K8sResourceKind.REQUESTS),
8838
"limits": res.to_kubernetes(kind=K8sResourceKind.LIMITS),
8939
}
90-
if job.options and (res := job.options.resources)
40+
if (res := job.options.resources)
9141
else None
9242
),
9343
)

src/jobs/runner/ray.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from jobs.job import RayResourceOptions
1818
from jobs.runner.base import Runner, _make_executor_command
1919
from jobs.types import K8sResourceKind, NoOptions
20-
from jobs.util import KubernetesNamespaceMixin, sanitize_rfc1123_domain_name
20+
from jobs.utils.kubernetes import KubernetesNamespaceMixin, sanitize_rfc1123_domain_name
21+
from jobs.utils.kueue import kueue_scheduling_labels
2122

2223

2324
class RayClusterRunner(Runner):
@@ -93,8 +94,7 @@ class RayJobRunner(Runner, KubernetesNamespaceMixin):
9394
def __init__(self, **kwargs):
9495
super().__init__(**kwargs)
9596

96-
@staticmethod
97-
def _create_ray_job(job: Job, image: Image) -> dict:
97+
def _create_ray_job(self, job: Job, image: Image) -> dict:
9898
"""Create a ``RayJob`` Kubernetes resource for the Kuberay operator."""
9999

100100
if job.options is None:
@@ -104,6 +104,8 @@ def _create_ray_job(job: Job, image: Image) -> dict:
104104
if not res_opts:
105105
raise ValueError("Job resource options must be set")
106106

107+
scheduling_labels = kueue_scheduling_labels(job, self.namespace)
108+
107109
runtime_env = {
108110
"working_dir": "/home/ray/app",
109111
}
@@ -115,6 +117,7 @@ def _create_ray_job(job: Job, image: Image) -> dict:
115117
"kind": "RayJob",
116118
"metadata": {
117119
"name": sanitize_rfc1123_domain_name(job_id),
120+
"labels": scheduling_labels,
118121
},
119122
"spec": {
120123
"jobId": job_id,

src/jobs/utils/__init__.py

Whitespace-only changes.

src/jobs/utils/helpers.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Mapping, TypeVar, cast
4+
5+
# Type of the mapping handed to ``remove_none_values``; the same type is
# reported back to the caller.
T = TypeVar("T", bound=Mapping[str, Any])


def remove_none_values(d: T) -> T:
    """Return a new dict holding the entries of ``d`` whose value is not ``None``.

    Only ``None`` is filtered out; other falsy values (``0``, ``""``,
    ``False``, empty containers) are kept. The input mapping is not modified.
    """
    kept = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    # The result is always a plain ``dict``; the cast only keeps the
    # caller-facing type unchanged for static type checkers.
    return cast(T, kept)

src/jobs/utils/kubernetes.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
import kubernetes
4+
5+
6+
def sanitize_rfc1123_domain_name(s: str) -> str:
    """Sanitize a string to be compliant with an RFC 1123 domain name.

    The input is lower-cased, every character other than ``a-z``, ``0-9``,
    ``-`` and ``.`` is replaced with a dash, leading and trailing dashes/dots
    are removed (a name must start and end with an alphanumeric character),
    and the result is limited to 253 characters.

    Note: The result can be empty if the input contains no alphanumeric
    characters at all; callers must handle that case themselves.
    """

    allowed = frozenset("abcdefghijklmnopqrstuvwxyz0123456789-.")
    max_length = 253

    replaced = "".join(ch if ch in allowed else "-" for ch in s.lower())
    # Strip again after truncating, since the cut may land on a dash or dot.
    return replaced.strip("-.")[:max_length].rstrip("-.")
13+
14+
15+
class KubernetesNamespaceMixin:
    """Determine the desired or current Kubernetes namespace."""

    def __init__(self, **kwargs):
        # Load the Kubernetes client configuration up front so that the
        # client APIs used by the runners are usable afterwards.
        kubernetes.config.load_config()
        # Explicitly requested namespace, if any (``namespace=...`` keyword).
        self._namespace: str | None = kwargs.get("namespace")

    @property
    def namespace(self) -> str:
        """The namespace to operate in.

        An explicitly configured namespace always wins. Otherwise, the
        namespace of the active kubeconfig context is used, falling back to
        ``"default"`` if the context does not set one.
        """
        # Short-circuit: do not read the kubeconfig contexts when the
        # namespace was given explicitly, so the lookup cannot fail (or cost
        # time) in that case.
        if self._namespace:
            return self._namespace

        _, active_context = kubernetes.config.list_kube_config_contexts()
        # A context without a namespace entry means the "default" namespace
        # (kubectl convention); never return ``None`` from a ``str`` property.
        return active_context["context"].get("namespace") or "default"

src/jobs/utils/kueue.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Mapping, cast
2+
3+
from kubernetes import client
4+
5+
from jobs.job import Job
6+
from jobs.utils.helpers import remove_none_values
7+
8+
9+
def assert_kueue_localqueue(namespace: str, name: str) -> bool:
    """Check the existence of a Kueue `LocalQueue` in a namespace.

    Returns ``True`` if the ``kueue.x-k8s.io/v1beta1`` ``localqueues`` object
    called ``name`` can be fetched from ``namespace``, and ``False`` if the
    API server answers with HTTP 404.

    Raises:
        kubernetes.client.exceptions.ApiException: For any API error other
            than "not found" (e.g. missing permissions or a server error),
            so that such failures are not misreported as a missing queue.
    """
    try:
        client.CustomObjectsApi().get_namespaced_custom_object(
            "kueue.x-k8s.io",
            "v1beta1",
            namespace,
            "localqueues",
            name,
        )
    except client.exceptions.ApiException as exc:
        # Only a 404 actually means "does not exist".
        if exc.status == 404:
            return False
        raise
    return True
22+
23+
24+
def assert_kueue_workloadpriorityclass(name: str) -> bool:
    """Check the existence of a Kueue `WorkloadPriorityClass` in the cluster.

    Returns ``True`` if the cluster-scoped ``kueue.x-k8s.io/v1beta1``
    ``workloadpriorityclasses`` object called ``name`` can be fetched, and
    ``False`` if the API server answers with HTTP 404.

    Raises:
        kubernetes.client.exceptions.ApiException: For any API error other
            than "not found" (e.g. missing permissions or a server error),
            so that such failures are not misreported as a missing class.
    """
    try:
        client.CustomObjectsApi().get_cluster_custom_object(
            "kueue.x-k8s.io",
            "v1beta1",
            "workloadpriorityclasses",
            name,
        )
    except client.exceptions.ApiException as exc:
        # Only a 404 actually means "does not exist".
        if exc.status == 404:
            return False
        raise
    return True
36+
37+
38+
def kueue_scheduling_labels(job: "Job", namespace: str) -> Mapping[str, str]:
    """Determine the Kubernetes labels controlling Kueue features such as queues and priority for a job.

    Args:
        job: The job whose ``options.scheduling`` settings are inspected.
        namespace: Namespace in which the referenced local queue is looked up.

    Returns:
        A mapping with the ``kueue.x-k8s.io/queue-name`` and/or
        ``kueue.x-k8s.io/priority-class`` labels. A label is only present if
        the corresponding option is set to a non-empty value; the mapping is
        empty if the job has no options or no scheduling options.

    Raises:
        ValueError: If the requested local queue or workload priority class
            does not exist.
    """

    if not job.options:
        return {}
    sched_opts = job.options.scheduling
    if not sched_opts:
        return {}

    labels: dict[str, str] = {}

    # Each label is added only after its (truthy) value has been validated,
    # so unset or empty options never end up as empty-valued labels.
    if queue := sched_opts.queue_name:
        if not assert_kueue_localqueue(namespace, queue):
            raise ValueError(f"Specified Kueue local queue does not exist: {queue!r}")
        labels["kueue.x-k8s.io/queue-name"] = queue

    if pc := sched_opts.priority_class:
        if not assert_kueue_workloadpriorityclass(pc):
            raise ValueError(
                f"Specified Kueue workload priority class does not exist: {pc!r}"
            )
        labels["kueue.x-k8s.io/priority-class"] = pc

    return labels

src/jobs/utils/math.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import re
4+
5+
6+
def to_rational(s: str) -> float:
    """Convert a number with optional SI/binary unit to floating-point.

    The magnitude may carry a sign and use either ``.`` or ``,`` as the
    decimal separator. Surrounding whitespace is ignored. Supported suffixes
    are the SI prefixes ``m``, ``k``, ``M``, ``G``, ``T``, ``P``, ``E`` and
    the binary prefixes ``Ki``, ``Mi``, ``Gi``, ``Ti``, ``Pi``, ``Ei``.

    Raises:
        ValueError: If the string is not a number followed by an optional
            suffix, or if the suffix is not one of the supported units.
    """

    # ``fullmatch`` rejects trailing text, so inputs like "1 Gi" or "1G5"
    # raise instead of silently dropping the unit or the tail.
    matches = re.fullmatch(
        r"(?P<magnitude>[+\-]?(?:\d+(?:[.,]\d*)?|[.,]\d+))(?P<suffix>[a-zA-Z]*)",
        s.strip(),
    )
    if not matches:
        raise ValueError(f"Could not parse {s}")
    # ``float`` only understands "." as the decimal separator.
    magnitude = float(matches.group("magnitude").replace(",", "."))
    suffix = matches.group("suffix")

    factor = {
        # SI / Metric
        "m": 1e-3,
        "k": 1e3,
        "M": 1e6,
        "G": 1e9,
        "T": 1e12,
        "P": 1e15,
        "E": 1e18,
        # Binary
        "Ki": 2**10,
        "Mi": 2**20,
        "Gi": 2**30,
        "Ti": 2**40,
        "Pi": 2**50,
        "Ei": 2**60,
        # default
        "": 1.0,
    }.get(suffix)
    if factor is None:
        raise ValueError(f"unknown unit suffix: {suffix}")

    return factor * magnitude

src/jobs/util.py src/jobs/utils/processes.py

+1-65
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,15 @@
11
from __future__ import annotations
22

3-
import re
43
import shlex
54
import subprocess
65
import sys
76
import threading
87
import time
98
from io import TextIOBase
10-
from typing import Any, Iterable, Mapping, TextIO, TypeVar, cast
11-
12-
import kubernetes
9+
from typing import Iterable, Mapping, TextIO
1310

1411
from jobs.types import AnyPath
1512

16-
T = TypeVar("T", bound=Mapping[str, Any])
17-
18-
19-
def to_rational(s: str) -> float:
20-
"""Convert a number with optional SI/binary unit to floating-point"""
21-
22-
matches = re.match(r"(?P<magnitude>[+\-]?\d*[.,]?\d+)(?P<suffix>[a-zA-Z]*)", s)
23-
if not matches:
24-
raise ValueError(f"Could not parse {s}")
25-
magnitude = float(matches.group("magnitude"))
26-
suffix = matches.group("suffix")
27-
28-
factor = {
29-
# SI / Metric
30-
"m": 1e-3,
31-
"k": 1e3,
32-
"M": 1e6,
33-
"G": 1e9,
34-
"T": 1e12,
35-
# Binary
36-
"Ki": 2**10,
37-
"Mi": 2**20,
38-
"Gi": 2**30,
39-
"Ti": 2**40,
40-
# default
41-
"": 1.0,
42-
}.get(suffix)
43-
if factor is None:
44-
raise ValueError(f"unknown unit suffix: {suffix}")
45-
46-
return factor * magnitude
47-
48-
49-
def remove_none_values(d: T) -> T:
50-
"""Remove all keys with a ``None`` value from a dict."""
51-
filtered_dict = {k: v for k, v in d.items() if v is not None}
52-
return cast(T, filtered_dict)
53-
54-
55-
def sanitize_rfc1123_domain_name(s: str) -> str:
56-
"""Sanitize a string to be compliant with RFC 1123 domain name
57-
58-
Note: Any invalid characters are replaced with dashes."""
59-
60-
# TODO: This is obviously wildly incomplete
61-
return s.replace("_", "-")
62-
6313

6414
def run_command(
6515
command: str,
@@ -157,17 +107,3 @@ def _reader(
157107
read_stderr.join()
158108

159109
return process.returncode, stdout, stderr, output
160-
161-
162-
class KubernetesNamespaceMixin:
163-
"""Determine the desired or current Kubernetes namespace."""
164-
165-
def __init__(self, **kwargs):
166-
kubernetes.config.load_config()
167-
self._namespace: str | None = kwargs.get("namespace")
168-
169-
@property
170-
def namespace(self) -> str:
171-
_, active_context = kubernetes.config.list_kube_config_contexts()
172-
current_namespace = active_context["context"].get("namespace")
173-
return self._namespace or current_namespace

tests/unit/test_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

3-
from jobs.util import remove_none_values, to_rational
3+
from jobs.utils.helpers import remove_none_values
4+
from jobs.utils.math import to_rational
45

56

67
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)