Backfill cancel: move the cancel logic from the REST endpoint into the model
layer and fail the queued dag runs of a cancelled backfill.

The endpoint cancel_backfill now delegates to _cancel_backfill in
airflow/models/backfill.py. _cancel_backfill raises Conflict when the backfill
already has completed_at set; otherwise it sets completed_at, pauses the
backfill, commits, and then updates every dag run linked to the backfill that
is still QUEUED to FAILED. Dag runs in any other state are left untouched.
The error logged when dag run creation fails now names the dag_id and the
logical_date. tests/models/test_backfill.py renames test_simple to
test_create_backfill_simple and adds test_cancel_backfill.

diff --git a/airflow/api_connexion/endpoints/backfill_endpoint.py b/airflow/api_connexion/endpoints/backfill_endpoint.py
index baafdeea4f992..a0e728c5bc464 100644
--- a/airflow/api_connexion/endpoints/backfill_endpoint.py
+++ b/airflow/api_connexion/endpoints/backfill_endpoint.py
@@ -32,8 +32,12 @@
     backfill_collection_schema,
     backfill_schema,
 )
-from airflow.models.backfill import AlreadyRunningBackfill, Backfill, _create_backfill
-from airflow.utils import timezone
+from airflow.models.backfill import (
+    AlreadyRunningBackfill,
+    Backfill,
+    _cancel_backfill,
+    _create_backfill,
+)
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.www.decorators import action_logging
 
@@ -104,24 +108,6 @@ def unpause_backfill(*, backfill_id, session, **kwargs):
     return backfill_schema.dump(br)
 
 
-@provide_session
-@backfill_to_dag
-@security.requires_access_dag("PUT")
-@action_logging
-def cancel_backfill(*, backfill_id, session, **kwargs):
-    br: Backfill = session.get(Backfill, backfill_id)
-    if br.completed_at is not None:
-        raise Conflict("Backfill is already completed.")
-
-    br.completed_at = timezone.utcnow()
-
-    # first, pause
-    if not br.is_paused:
-        br.is_paused = True
-    session.commit()
-    return backfill_schema.dump(br)
-
-
 @provide_session
 @backfill_to_dag
 @security.requires_access_dag("GET")
@@ -155,3 +141,17 @@ def create_backfill(
         return backfill_schema.dump(backfill_obj)
     except AlreadyRunningBackfill:
         raise Conflict(f"There is already a running backfill for dag {dag_id}")
+
+
+@provide_session
+@backfill_to_dag
+@security.requires_access_dag("PUT")
+@action_logging
+def cancel_backfill(
+    *,
+    backfill_id,
+    session: Session = NEW_SESSION,  # used by backfill_to_dag decorator
+    **kwargs,
+):
+    br = _cancel_backfill(backfill_id=backfill_id)
+    return backfill_schema.dump(br)
diff --git a/airflow/models/backfill.py b/airflow/models/backfill.py
index 6d3a8ee4fa922..db10c804aac0d 100644
--- a/airflow/models/backfill.py
+++ b/airflow/models/backfill.py
@@ -26,12 +26,13 @@
 import logging
 from typing import TYPE_CHECKING
 
-from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select
+from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select, update
 from sqlalchemy.orm import relationship
 from sqlalchemy_jsonfield import JSONField
 
-from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.exceptions import Conflict, NotFound
 from airflow.exceptions import AirflowException
+from airflow.models import DagRun
 from airflow.models.base import Base, StringID
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.settings import json
@@ -48,7 +49,11 @@
 
 
 class AlreadyRunningBackfill(AirflowException):
-    """Raised when attempting to create backfill and one already active."""
+    """
+    Raised when attempting to create backfill and one already active.
+
+    :meta private:
+    """
 
 
 class Backfill(Base):
@@ -172,7 +177,11 @@ def _create_backfill(
                     session=session,
                 )
             except Exception:
-                dag.log.exception("something failed")
+                dag.log.exception(
+                    "Error while attempting to create a dag run dag_id='%s' logical_date='%s'",
+                    dag.dag_id,
+                    info.logical_date,
+                )
                 session.rollback()
             session.add(
                 BackfillDagRun(
@@ -183,3 +192,31 @@ def _create_backfill(
             )
             session.commit()
     return br
+
+
+def _cancel_backfill(backfill_id) -> Backfill:
+    with create_session() as session:
+        b: Backfill = session.get(Backfill, backfill_id)
+        if b.completed_at is not None:
+            raise Conflict("Backfill is already completed.")
+
+        b.completed_at = timezone.utcnow()
+
+        # first, pause
+        if not b.is_paused:
+            b.is_paused = True
+
+        session.commit()
+
+        # now, let's mark all queued dag runs as failed
+        query = (
+            update(DagRun)
+            .where(
+                DagRun.id.in_(select(BackfillDagRun.dag_run_id).where(BackfillDagRun.backfill_id == b.id)),
+                DagRun.state == DagRunState.QUEUED,
+            )
+            .values(state=DagRunState.FAILED)
+            .execution_options(synchronize_session=False)
+        )
+        session.execute(query)
+    return b
diff --git a/tests/models/test_backfill.py b/tests/models/test_backfill.py
index 9a845f86803e0..c45625db335de 100644
--- a/tests/models/test_backfill.py
+++ b/tests/models/test_backfill.py
@@ -24,7 +24,13 @@
 from sqlalchemy import select
 
 from airflow.models import DagRun
-from airflow.models.backfill import AlreadyRunningBackfill, Backfill, BackfillDagRun, _create_backfill
+from airflow.models.backfill import (
+    AlreadyRunningBackfill,
+    Backfill,
+    BackfillDagRun,
+    _cancel_backfill,
+    _create_backfill,
+)
 from airflow.operators.python import PythonOperator
 from airflow.utils.state import DagRunState
 from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -71,7 +77,7 @@ def test_reverse_and_depends_on_past_fails(dep_on_past, dag_maker, session):
 
 
 @pytest.mark.parametrize("reverse", [True, False])
-def test_simple(reverse, dag_maker, session):
+def test_create_backfill_simple(reverse, dag_maker, session):
     """
     Verify simple case behavior.
 
@@ -150,3 +156,38 @@ def test_active_dag_run(dag_maker, session):
             reverse=False,
             dag_run_conf={"this": "param"},
         )
+
+
+def test_cancel_backfill(dag_maker, session):
+    """
+    Queued runs should be marked *failed*.
+    Every other dag run should be left alone.
+    """
+    with dag_maker(schedule="@daily") as dag:
+        PythonOperator(task_id="hi", python_callable=print)
+    b = _create_backfill(
+        dag_id=dag.dag_id,
+        from_date=pendulum.parse("2021-01-01"),
+        to_date=pendulum.parse("2021-01-05"),
+        max_active_runs=2,
+        reverse=False,
+        dag_run_conf={},
+    )
+    query = (
+        select(DagRun)
+        .join(BackfillDagRun.dag_run)
+        .where(BackfillDagRun.backfill_id == b.id)
+        .order_by(BackfillDagRun.sort_ordinal)
+    )
+    dag_runs = session.scalars(query).all()
+    dates = [str(x.logical_date.date()) for x in dag_runs]
+    expected_dates = ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04", "2021-01-05"]
+    assert dates == expected_dates
+    assert all(x.state == DagRunState.QUEUED for x in dag_runs)
+    dag_runs[0].state = "running"
+    session.commit()
+    _cancel_backfill(backfill_id=b.id)
+    session.expunge_all()
+    dag_runs = session.scalars(query).all()
+    states = [x.state for x in dag_runs]
+    assert states == ["running", "failed", "failed", "failed", "failed"]