Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add project creation API endpoint #139

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/src/jobq_server/__main__.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@

from jobq_server.config import settings
from jobq_server.db import check_migrations, get_engine, upgrade_migrations
from jobq_server.routers import jobs
from jobq_server.routers import jobs, projects


@asynccontextmanager
@@ -36,6 +36,7 @@ async def lifespan(app: FastAPI):
)

app.include_router(jobs.router, prefix="/jobs")
app.include_router(projects.router, prefix="/projects")


@app.get("/health", include_in_schema=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Initial schema

Revision ID: 2837c7c54f35
Revises:
Create Date: 2024-10-31 11:27:32.242586

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '2837c7c54f35'
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
    """Create the ``project`` table together with its secondary indexes."""
    project_columns = [
        sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('cluster_queue', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('local_queue', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('namespace', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('id', sa.Uuid(), nullable=False),
    ]
    op.create_table('project', *project_columns, sa.PrimaryKeyConstraint('id'))
    # Indexes after the table exists; project names must be unique.
    op.create_index(op.f('ix_project_description'), 'project', ['description'], unique=False)
    op.create_index(op.f('ix_project_name'), 'project', ['name'], unique=True)


def downgrade():
    """Reverse of :func:`upgrade`: drop the indexes, then the ``project`` table."""
    for index_name in ('ix_project_name', 'ix_project_description'):
        op.drop_index(op.f(index_name), table_name='project')
    op.drop_table('project')
25 changes: 23 additions & 2 deletions backend/src/jobq_server/db.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from threading import Lock
from uuid import UUID, uuid4

from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import Engine
from sqlmodel import SQLModel as SQLModel
from sqlmodel import create_engine
from sqlmodel import Field, SQLModel, create_engine

from jobq_server.config import settings

@@ -50,3 +50,24 @@ def upgrade_migrations():

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")


# --- PROJECT
class ProjectBase(SQLModel):
    """Fields shared by every project schema variant (table row, create payload, public view)."""

    # Human-readable project name; indexed and unique across the deployment.
    name: str = Field(index=True, unique=True)
    description: str | None = Field(None, index=True)
    # Kueue cluster queue backing the project's workloads (created on demand by the API).
    cluster_queue: str | None = Field(None)
    # Kueue local queue inside the project's namespace.
    local_queue: str | None = Field(None)
    # Kubernetes namespace the project's resources live in.
    namespace: str | None = Field(None)


class Project(ProjectBase, table=True):
    """Database table model for projects; adds the server-generated primary key."""

    id: UUID = Field(default_factory=uuid4, primary_key=True)


class ProjectCreate(ProjectBase):
    """Request payload for creating a project (no ``id``; the server assigns one)."""

    pass


class ProjectPublic(ProjectBase):
    """API response shape: the shared fields plus the assigned ``id``."""

    id: UUID
10 changes: 9 additions & 1 deletion backend/src/jobq_server/dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -7,13 +7,21 @@
from jobq_server.db import get_engine
from jobq_server.models import JobId
from jobq_server.services.k8s import KubernetesService
from jobq_server.services.kueue import KueueService
from jobq_server.utils.kueue import KueueWorkload


def k8s_service() -> KubernetesService:
    """Dependency factory: construct a KubernetesService for injection."""
    service = KubernetesService()
    return service


KubernetesDep = Annotated[KubernetesService, Depends(k8s_service)]


def kueue_service(k8s: KubernetesDep) -> KueueService:
    """Dependency factory: wrap the injected Kubernetes service in a KueueService."""
    kueue = KueueService(k8s)
    return kueue


def managed_workload(
k8s: Annotated[KubernetesService, Depends(k8s_service)],
uid: JobId,
@@ -31,5 +39,5 @@ def get_session() -> Generator[Session, None, None]:


ManagedWorkload = Annotated[KueueWorkload, Depends(managed_workload)]
Kubernetes = Annotated[KubernetesService, Depends(k8s_service)]
KueueDep = Annotated[KueueService, Depends(kueue_service)]
DBSessionDep = Annotated[Session, Depends(get_session)]
10 changes: 5 additions & 5 deletions backend/src/jobq_server/routers/jobs.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
from fastapi.responses import StreamingResponse
from jobq import Job

from jobq_server.dependencies import Kubernetes, ManagedWorkload
from jobq_server.dependencies import KubernetesDep, ManagedWorkload
from jobq_server.exceptions import PodNotReadyError
from jobq_server.models import (
CreateJobModel,
@@ -28,7 +28,7 @@
@router.post("")
async def submit_job(
opts: CreateJobModel,
k8s: Kubernetes,
k8s: KubernetesDep,
) -> WorkloadIdentifier:
# FIXME: Having to define a function just to set the job name is ugly
def job_fn(): ...
@@ -72,7 +72,7 @@ async def status(
@router.get("/{uid}/logs")
async def logs(
workload: ManagedWorkload,
k8s: Kubernetes,
k8s: KubernetesDep,
params: Annotated[LogOptions, Depends(make_dependable(LogOptions))],
):
try:
@@ -164,7 +164,7 @@ async def stream_response(
async def stop_workload(
uid: JobId,
workload: ManagedWorkload,
k8s: Kubernetes,
k8s: KubernetesDep,
):
try:
workload.stop(k8s)
@@ -182,7 +182,7 @@ async def stop_workload(

@router.get("", response_model_exclude_unset=True)
async def list_jobs(
k8s: Kubernetes,
k8s: KubernetesDep,
include_metadata: Annotated[bool, Query()] = False,
) -> list[ListWorkloadModel]:
workloads = k8s.list_workloads()
83 changes: 83 additions & 0 deletions backend/src/jobq_server/routers/projects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import logging

from fastapi import APIRouter
from kubernetes import client
from sqlmodel import select

from jobq_server.db import Project, ProjectCreate, ProjectPublic
from jobq_server.dependencies import DBSessionDep, KubernetesDep, KueueDep
from jobq_server.utils.kueue import ClusterQueue, ClusterQueueSpec

FINALIZER = "jobq.example.com/project-protection"
router = APIRouter()


@router.get("/")
async def list_projects(db: DBSessionDep):
    """Return every project row stored in the database."""
    statement = select(Project)
    results = db.exec(statement)
    return results.all()


@router.post("/", status_code=201)
async def create_project(
    project: ProjectCreate,
    db: DBSessionDep,
    k8s: KubernetesDep,
    kueue: KueueDep,
) -> ProjectPublic:
    """Provision cluster resources for a new project and persist it.

    Ensures the project's Kubernetes namespace, Kueue cluster queue, and
    local queue exist (creating each with defaults when missing), attaches
    a protective finalizer to the namespace, then stores the project row
    and returns its public representation.
    """
    # Namespace: look up or create.
    ns, ns_created = k8s.ensure_namespace(project.namespace)
    if ns_created:
        logging.info(f"Created Kubernetes namespace {ns.metadata.name}")

    # Cluster queue: create with a default quota spec when absent.
    if kueue.get_cluster_queue(project.cluster_queue) is None:
        preemption = {
            "reclaimWithinCohort": "Any",
            "borrowWithinCohort": {
                "policy": "LowerPriority",
                "maxPriorityThreshold": 100,
            },
            "withinClusterQueue": "LowerPriority",
        }
        resource_groups = [
            {
                "coveredResources": ["cpu", "memory"],
                "flavors": [
                    {
                        "name": "default-flavor",
                        "resources": [
                            {"name": "cpu", "nominalQuota": 4},
                            {"name": "memory", "nominalQuota": 6},
                        ],
                    }
                ],
            }
        ]
        queue_spec = ClusterQueueSpec.model_validate(
            {
                "namespaceSelector": {},
                "preemption": preemption,
                "resourceGroups": resource_groups,
            }
        )
        kueue.create_cluster_queue(
            ClusterQueue(
                metadata=client.V1ObjectMeta(name=project.cluster_queue),
                spec=queue_spec,
            )
        )
        logging.info(f"Created cluster queue {project.cluster_queue!r}")

    # Local queue: bind the namespace to the cluster queue.
    _, local_queue_created = kueue.ensure_local_queue(
        project.local_queue, project.namespace, project.cluster_queue
    )
    if local_queue_created:
        logging.info(
            f"Created user queue {project.local_queue!r} in namespace {project.namespace!r}"
        )

    # Protect the namespace from deletion while the project exists.
    k8s.add_finalizer(ns, FINALIZER)
    # TODO: Apply finalizers to Kubernetes resources to prevent deletion while the project exists

    # Persist the project row and return it with its generated id.
    db_obj = Project.model_validate(project)
    db.add(db_obj)
    db.commit()
    db.refresh(db_obj)
    return db_obj
43 changes: 42 additions & 1 deletion backend/src/jobq_server/services/k8s.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@ def __init__(self):
)
config.load_kube_config()
self._in_cluster = False

self._core_v1_api = client.CoreV1Api()

@property
@@ -129,3 +128,45 @@ def list_workloads(self, namespace: str | None = None) -> list[KueueWorkload]:
KueueWorkload.model_validate(workload)
for workload in workloads.get("items", [])
]

def ensure_namespace(self, name: str) -> tuple[client.V1Namespace, bool]:
    """Return the named namespace, creating it if it does not exist.

    Returns
    -------
    tuple[client.V1Namespace, bool]
        The namespace object and whether it had to be created.
    """
    try:
        existing = self._core_v1_api.read_namespace(name)
    except client.ApiException as e:
        # Only a 404 means "missing"; anything else is a real API error.
        if e.status != 404:
            raise
        body = client.V1Namespace(metadata=client.V1ObjectMeta(name=name))
        return self._core_v1_api.create_namespace(body), True
    return existing, False

def add_finalizer(
    self,
    resource: client.V1Namespace | client.V1CustomResourceDefinition,
    finalizer: str,
) -> None:
    """Add a finalizer to a Kubernetes resource and persist the change.

    The finalizer is appended only when not already present, but the
    resource is always written back to the API server (preserving the
    original behavior of unconditionally replacing the object).

    Parameters
    ----------
    resource : client.V1Namespace | client.V1CustomResourceDefinition
        The live resource object to protect; mutated in place.
    finalizer : str
        Finalizer identifier, e.g. ``jobq.example.com/project-protection``.
    """
    if resource.metadata.finalizers is None:
        resource.metadata.finalizers = []
    if finalizer not in resource.metadata.finalizers:
        resource.metadata.finalizers.append(finalizer)

    if isinstance(resource, client.V1Namespace):
        self._core_v1_api.replace_namespace(resource.metadata.name, resource)
    elif isinstance(resource, client.V1CustomResourceDefinition):
        # Fix: CRDs are cluster-scoped, so the previous
        # replace_namespaced_custom_object(namespace=None, ...) call could not
        # succeed. CRDs are managed through the apiextensions API, not
        # CustomObjectsApi (which targets *instances* of custom resources).
        client.ApiextensionsV1Api().replace_custom_resource_definition(
            resource.metadata.name, resource
        )
    else:
        # Generic fallback for other namespaced custom objects, preserving the
        # original path. NOTE(review): objects returned by the typed client
        # often have api_version/kind unset — confirm callers populate them
        # before relying on this branch.
        api = client.CustomObjectsApi()
        group, version = resource.api_version.split("/")
        api.replace_namespaced_custom_object(
            group=group,
            version=version,
            namespace=resource.metadata.namespace,
            plural=resource.kind.lower() + "s",
            name=resource.metadata.name,
            body=resource,
        )
Loading