diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 0789b4ee88d82..60c77c7a37feb 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -512,7 +512,6 @@ def string_lower_type(val): ("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true" ) ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)") -ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error") ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS) ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg") ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index") @@ -1264,7 +1263,6 @@ class GroupCommand(NamedTuple): ARG_PICKLE, ARG_JOB_ID, ARG_INTERACTIVE, - ARG_ERROR_FILE, ARG_SHUT_DOWN_LOGGING, ARG_MAP_INDEX, ), diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index ea20ebb64615e..78e8fc20f6e49 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -45,6 +45,7 @@ from airflow.utils import cli as cli_utils from airflow.utils.cli import ( get_dag, + get_dag_by_deserialization, get_dag_by_file_location, get_dag_by_pickle, get_dags, @@ -258,7 +259,6 @@ def _run_raw_task(args, ti: TaskInstance) -> None: mark_success=args.mark_success, job_id=args.job_id, pool=args.pool, - error_file=args.error_file, ) @@ -357,7 +357,10 @@ def task_run(args, dag=None): print(f'Loading pickle id: {args.pickle}') dag = get_dag_by_pickle(args.pickle) elif not dag: - dag = get_dag(args.subdir, args.dag_id) + if args.local: + dag = get_dag_by_deserialization(args.dag_id) + else: + dag = get_dag(args.subdir, args.dag_id) else: # Use DAG from parameter pass diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 469b55cfeb621..34821ffb92010 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -607,7 +607,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index) # TODO: Use simple_ti to improve performance here in the future ti.refresh_from_db() - ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE) + ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE) self.log.info('Executed failure callback for %s in state %s', ti, ti.state) @provide_session diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 865186dd18ec7..7a2dddec79358 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -66,7 +66,6 @@ def sync(self) -> None: self.log.info("Executor is terminated! 
Stopping %s to %s", ti.key, State.FAILED) ti.set_state(State.FAILED) self.change_state(ti.key, State.FAILED) - ti._run_finished_callback() continue task_succeeded = self._run_task(ti) @@ -78,12 +77,10 @@ def _run_task(self, ti: TaskInstance) -> bool: params = self.tasks_params.pop(ti.key, {}) ti._run_raw_task(job_id=ti.job_id, **params) self.change_state(key, State.SUCCESS) - ti._run_finished_callback() return True except Exception as e: ti.set_state(State.FAILED) self.change_state(key, State.FAILED) - ti._run_finished_callback(error=e) self.log.exception("Failed to execute task: %s.", str(e)) return False diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 6695566cc72f0..c5b98c2a8df2a 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -261,7 +261,7 @@ def _manage_executor_state( f"{ti.state}. Was the task killed externally? Info: {info}" ) self.log.error(msg) - ti.handle_failure_with_callback(error=msg) + ti.handle_failure(error=msg) continue if ti.state not in self.STATES_COUNT_AS_RUNNING: # Don't use ti.task; if this task is mapped, that attribute diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 9b2c3510ebfd0..5711342e04d97 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -73,6 +73,8 @@ def __init__( # terminate multiple times self.terminating = False + self._state_change_checks = 0 + super().__init__(*args, **kwargs) def _execute(self): @@ -84,7 +86,6 @@ def signal_handler(signum, frame): self.log.error("Received SIGTERM. Terminating subprocesses") self.task_runner.terminate() self.handle_task_exit(128 + signum) - return signal.signal(signal.SIGTERM, signal_handler) @@ -106,13 +107,15 @@ def signal_handler(signum, frame): heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold') - # task callback invocation happens either here or in - # self.heartbeat() instead of taskinstance._run_raw_task to - # avoid race conditions - # - # When self.terminating is set to True by heartbeat_callback, this - # loop should not be restarted. Otherwise self.handle_task_exit - # will be invoked and we will end up with duplicated callbacks + # LocalTaskJob should not run callbacks, which are handled by TaskInstance._run_raw_task + # 1, LocalTaskJob does not parse DAG, thus cannot run callbacks + # 2, The run_as_user of LocalTaskJob is likely not same as the TaskInstance._run_raw_task. + # When run_as_user is specified, the process owner of the LocalTaskJob must be sudoable. + # It is not secure to run callbacks with sudoable users. + + # If _run_raw_task receives SIGKILL, scheduler will mark it as zombie and invoke callbacks + # If LocalTaskJob receives SIGTERM, LocalTaskJob passes SIGTERM to _run_raw_task + # If the state of task_instance is changed, LocalTaskJob sends SIGTERM to _run_raw_task while not self.terminating: # Monitor the task to see if it's done. 
Wait in a syscall # (`os.wait`) for as long as possible so we notice the @@ -150,26 +153,18 @@ def signal_handler(signum, frame): self.on_kill() def handle_task_exit(self, return_code: int) -> None: - """Handle case where self.task_runner exits by itself or is externally killed""" + """ + Handle case where self.task_runner exits by itself or is externally killed + + Dont run any callbacks + """ # Without setting this, heartbeat may get us self.terminating = True self.log.info("Task exited with return code %s", return_code) - self.task_instance.refresh_from_db() - if self.task_instance.state == State.RUNNING: - # This is for a case where the task received a SIGKILL - # while running or the task runner received a sigterm - self.task_instance.handle_failure(error=None) - # We need to check for error file - # in case it failed due to runtime exception/error - error = None - if self.task_instance.state != State.SUCCESS: - error = self.task_runner.deserialize_run_error() - self.task_instance._run_finished_callback(error=error) if not self.task_instance.test_mode: if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): self._run_mini_scheduler_on_child_tasks() - self._update_dagrun_state_for_paused_dag() def on_kill(self): self.task_runner.terminate() @@ -217,19 +212,16 @@ def heartbeat_callback(self, session=None): dagrun_timeout = ti.task.dag.dagrun_timeout if dagrun_timeout and execution_time > dagrun_timeout: self.log.warning("DagRun timed out after %s.", str(execution_time)) - self.log.warning( - "State of this instance has been externally set to %s. Terminating instance.", ti.state - ) - self.task_runner.terminate() - if ti.state == State.SUCCESS: - error = None - else: - # if ti.state is not set by taskinstance.handle_failure, then - # error file will not be populated and it must be updated by - # external source such as web UI - error = self.task_runner.deserialize_run_error() or "task marked as failed externally" - ti._run_finished_callback(error=error) - self.terminating = True + + # potential race condition, the _run_raw_task commits `success` or other state + # but task_runner does not exit right away due to slow process shutdown or any other reasons + # let's do a throttle here, if the above case is true, the handle_task_exit will handle it + if self._state_change_checks >= 1: # defer to next round of heartbeat + self.log.warning( + "State of this instance has been externally set to %s. 
Terminating instance.", ti.state + ) + self.terminating = True + self._state_change_checks += 1 @provide_session @Sentry.enrich_errors @@ -282,19 +274,6 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: ) session.rollback() - @provide_session - def _update_dagrun_state_for_paused_dag(self, session=None): - """ - Checks for paused dags with DagRuns in the running state and - update the DagRun state if possible - """ - dag = self.task_instance.task.dag - if dag.get_is_paused(): - dag_run = self.task_instance.get_dagrun(session=session) - if dag_run: - dag_run.dag = dag - dag_run.update_state(session=session, execute_callbacks=True) - @staticmethod def _enable_task_listeners(): """ diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index ac1d25833b5aa..853c91e79fb9b 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -149,6 +149,7 @@ def __init__( self.processor_agent: Optional[DagFileProcessorAgent] = None self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False) + self._paused_dag_without_running_dagruns: Set = set() if conf.getboolean('smart_sensor', 'use_smart_sensor'): compatible_sensors = set( @@ -764,6 +765,26 @@ def _execute(self) -> None: self.log.exception("Exception when executing DagFileProcessorAgent.end") self.log.info("Exited execute loop") + def _update_dag_run_state_for_paused_dags(self): + try: + paused_dag_ids = DagModel.get_all_paused_dag_ids() + for dag_id in paused_dag_ids: + if dag_id in self._paused_dag_without_running_dagruns: + continue + + dag = SerializedDagModel.get_dag(dag_id) + if dag is None: + continue + dag_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING) + for dag_run in dag_runs: + dag_run.dag = dag + _, callback_to_run = dag_run.update_state(execute_callbacks=False) + if callback_to_run: + self._send_dag_callbacks_to_processor(dag, callback_to_run) + self._paused_dag_without_running_dagruns.add(dag_id) + except Exception as e: # should not fail the scheduler + self.log.exception('Failed to update dag run state for paused dags due to %s', str(e)) + def _run_scheduler_loop(self) -> None: """ The actual scheduler loop. 
The main steps in the loop are: @@ -809,6 +830,7 @@ def _run_scheduler_loop(self) -> None: conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0), self._find_zombies, ) + timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags) for loop_count in itertools.count(start=1): with Stats.timer() as timer: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 527928adb915b..1ba4bd8e0faa7 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2762,6 +2762,15 @@ def get_dagmodel(dag_id, session=NEW_SESSION): def get_current(cls, dag_id, session=NEW_SESSION): return session.query(cls).filter(cls.dag_id == dag_id).first() + @staticmethod + @provide_session + def get_all_paused_dag_ids(session: Session = NEW_SESSION) -> Set[str]: + """Get a set of paused DAG ids""" + paused_dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_paused == expression.true()).all() + + paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids} + return paused_dag_ids + @provide_session def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): return get_last_dagrun( diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 23d04992d3c75..45424709a8421 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -250,6 +250,14 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool: """ return session.query(literal(True)).filter(cls.dag_id == dag_id).first() is not None + @classmethod + @provide_session + def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDAG']: + row = cls.get(dag_id, session=session) + if row: + return row.dag + return None + @classmethod @provide_session def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 123b4ddea38a2..c745442a21d65 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -22,16 +22,13 @@ import math import operator import os -import pickle import signal import warnings from collections import defaultdict from datetime import datetime, timedelta from functools import partial -from tempfile import NamedTemporaryFile from types import TracebackType from typing import ( - IO, TYPE_CHECKING, Any, Callable, @@ -160,31 +157,6 @@ def set_current_context(context: Context) -> Iterator[Context]: ) -def load_error_file(fd: IO[bytes]) -> Optional[Union[str, Exception]]: - """Load and return error from error file""" - if fd.closed: - return None - fd.seek(0, os.SEEK_SET) - data = fd.read() - if not data: - return None - try: - return pickle.loads(data) - except Exception: - return "Failed to load task run error" - - -def set_error_file(error_file: str, error: Union[str, BaseException]) -> None: - """Write error into error file by path""" - with open(error_file, "wb") as fd: - try: - pickle.dump(error, fd) - except Exception: - # local class objects cannot be pickled, so we fallback - # to store the string representation instead - pickle.dump(str(error), fd) - - def clear_task_instances( tis, session, @@ -1415,7 +1387,6 @@ def _run_raw_task( test_mode: bool = False, job_id: Optional[str] = None, pool: Optional[str] = None, - error_file: Optional[str] = None, session=NEW_SESSION, ) -> None: """ @@ -1443,10 +1414,11 @@ def _run_raw_task( # Initialize final state counters at zero for state in State.task_states: 
Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}', count=0) + + self.task = self.task.prepare_for_execution() + context = self.get_template_context(ignore_param_exceptions=False) try: if not mark_success: - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False) self._execute_task_with_callbacks(context, test_mode) if not test_mode: self.refresh_from_db(lock_for_update=True, session=session) @@ -1485,7 +1457,7 @@ def _run_raw_task( except (AirflowFailException, AirflowSensorTimeout) as e: # If AirflowFailException is raised, task should not retry. # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure(e, test_mode, force_fail=True, error_file=error_file, session=session) + self.handle_failure(e, test_mode, context, force_fail=True, session=session) session.commit() raise except AirflowException as e: @@ -1500,11 +1472,11 @@ def _run_raw_task( session.commit() return else: - self.handle_failure(e, test_mode, error_file=error_file, session=session) + self.handle_failure(e, test_mode, context, session=session) session.commit() raise except (Exception, KeyboardInterrupt) as e: - self.handle_failure(e, test_mode, error_file=error_file, session=session) + self.handle_failure(e, test_mode, context, session=session) session.commit() raise finally: @@ -1515,6 +1487,12 @@ def _run_raw_task( self.end_date = timezone.utcnow() self._log_state() self.set_duration() + + # run on_success_callback before db committing + # otherwise, the LocalTaskJob sees the state is changed to `success`, + # but the task_runner is still running, LocalTaskJob then treats the state is set externally! + self._run_finished_callback(self.task.on_success_callback, context, 'on_success') + if not test_mode: session.add(Log(self.state, self)) session.merge(self) @@ -1612,6 +1590,14 @@ def _update_ti_state_for_sensing(self, session=NEW_SESSION): # Raise exception for sensing state raise AirflowSmartSensorException("Task successfully registered in smart sensor.") + def _run_finished_callback(self, callback, context, callback_type): + """Run callback after task finishes""" + try: + if callback: + callback(context) + except Exception: # pylint: disable=broad-except + self.log.exception(f"Error when executing {callback_type} callback") + def _execute_task(self, context, task_orig): """Executes Task (optionally with a Timeout) and pushes Xcom results""" task_to_execute = self.task @@ -1713,40 +1699,6 @@ def _run_execute_callback(self, context: Context, task): except Exception: self.log.exception("Failed when executing execute callback") - def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) -> None: - """ - Call callback defined for finished state change. 
- - NOTE: Only invoke this function from caller of self._run_raw_task or - self.run - """ - if self.state == State.FAILED: - task = self.task - if task.on_failure_callback is not None: - context = self.get_template_context() - context["exception"] = error - try: - task.on_failure_callback(context) - except Exception: - self.log.exception("Error when executing on_failure_callback") - elif self.state == State.SUCCESS: - task = self.task - if task.on_success_callback is not None: - context = self.get_template_context() - try: - task.on_success_callback(context) - except Exception: - self.log.exception("Error when executing on_success_callback") - elif self.state == State.UP_FOR_RETRY: - task = self.task - if task.on_retry_callback is not None: - context = self.get_template_context() - context["exception"] = error - try: - task.on_retry_callback(context) - except Exception: - self.log.exception("Error when executing on_retry_callback") - @provide_session def run( self, @@ -1777,20 +1729,9 @@ def run( if not res: return - try: - error_fd = NamedTemporaryFile(delete=True) - self._run_raw_task( - mark_success=mark_success, - test_mode=test_mode, - job_id=job_id, - pool=pool, - error_file=error_fd.name, - session=session, - ) - finally: - error = None if self.state == State.SUCCESS else load_error_file(error_fd) - error_fd.close() - self._run_finished_callback(error=error) + self._run_raw_task( + mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session + ) def dry_run(self): """Only Renders Templates for the TI""" @@ -1870,28 +1811,20 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) - return tb or error.__traceback__ @provide_session - def handle_failure( - self, - error: Union[None, str, BaseException] = None, - test_mode: Optional[bool] = None, - force_fail: bool = False, - error_file: Optional[str] = None, - session: Session = NEW_SESSION, - ) -> None: + def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None) -> None: """Handle Failure for the TaskInstance""" if test_mode is None: test_mode = self.test_mode + if context is None: + context = self.get_template_context() + if error: if isinstance(error, BaseException): tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task) self.log.error("Task failed with exception", exc_info=(type(error), error, tb)) else: self.log.error("%s", error) - # external monitoring process provides pickle file so _run_raw_task - # can send its runtime errors for access by failure callback - if error_file: - set_error_file(error_file, error) if not test_mode: self.refresh_from_db(session) @@ -1907,6 +1840,9 @@ def handle_failure( self.clear_next_method_args() + if context is not None: + context['exception'] = error + # Set state correctly and figure out how to log it and decide whether # to email @@ -1928,12 +1864,16 @@ def handle_failure( if force_fail or not self.is_eligible_to_retry(): self.state = State.FAILED email_for_state = operator.attrgetter('email_on_failure') + callback = task.on_failure_callback if task else None + callback_type = 'on_failure' else: if self.state == State.QUEUED: # We increase the try_number so as to fail the task if it fails to start after sometime self._try_number += 1 self.state = State.UP_FOR_RETRY email_for_state = operator.attrgetter('email_on_retry') + callback = task.on_retry_callback if task else None + callback_type = 'on_retry' self._log_state('Immediate failure requested. 
' if force_fail else '') if task and email_for_state(task) and task.email: @@ -1942,21 +1882,13 @@ def handle_failure( except Exception: self.log.exception('Failed to send email to: %s', task.email) + if callback: + self._run_finished_callback(callback, context, callback_type) + if not test_mode: session.merge(self) session.flush() - @provide_session - def handle_failure_with_callback( - self, - error: Union[None, str, Exception], - test_mode: Optional[bool] = None, - force_fail: bool = False, - session=NEW_SESSION, - ) -> None: - self.handle_failure(error=error, test_mode=test_mode, force_fail=force_fail, session=session) - self._run_finished_callback(error=error) - def is_eligible_to_retry(self): """Is task instance is eligible for retry""" if self.state == State.RESTARTING: diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 47be386b7489a..55dcf05d34b43 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -26,12 +26,10 @@ # ignored to avoid flake complaining on Linux from pwd import getpwnam # noqa -from tempfile import NamedTemporaryFile -from typing import Optional, Union +from typing import Optional from airflow.configuration import conf from airflow.exceptions import AirflowConfigException -from airflow.models.taskinstance import load_error_file from airflow.utils.configuration import tmp_configuration_copy from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname @@ -63,8 +61,6 @@ def __init__(self, local_task_job): except AirflowConfigException: self.run_as_user = None - self._error_file = NamedTemporaryFile(delete=True) - # Add sudo commands to change user if we need to. Needed to handle SubDagOperator # case using a SequentialExecutor. 
self.log.debug("Planning to run as the %s user", self.run_as_user) @@ -76,9 +72,7 @@ def __init__(self, local_task_job): cfg_path = tmp_configuration_copy(chmod=0o600, include_env=True, include_cmds=True) # Give ownership of file to user; only they can read and write - subprocess.check_call( - ['sudo', 'chown', self.run_as_user, cfg_path, self._error_file.name], close_fds=True - ) + subprocess.check_call(['sudo', 'chown', self.run_as_user, cfg_path], close_fds=True) # propagate PYTHONPATH environment variable pythonpath_value = os.environ.get(PYTHONPATH_VAR, '') @@ -95,24 +89,16 @@ def __init__(self, local_task_job): cfg_path = tmp_configuration_copy(chmod=0o600, include_env=False, include_cmds=False) self._cfg_path = cfg_path - self._command = ( - popen_prepend - + self._task_instance.command_as_list( - raw=True, - pickle_id=local_task_job.pickle_id, - mark_success=local_task_job.mark_success, - job_id=local_task_job.id, - pool=local_task_job.pool, - cfg_path=cfg_path, - ) - + ["--error-file", self._error_file.name] + self._command = popen_prepend + self._task_instance.command_as_list( + raw=True, + pickle_id=local_task_job.pickle_id, + mark_success=local_task_job.mark_success, + job_id=local_task_job.id, + pool=local_task_job.pool, + cfg_path=cfg_path, ) self.process = None - def deserialize_run_error(self) -> Optional[Union[str, Exception]]: - """Return task runtime error if its written to provided error file.""" - return load_error_file(self._error_file) - def _read_task_logs(self, stream): while True: line = stream.readline() @@ -193,9 +179,3 @@ def on_finish(self) -> None: subprocess.call(['sudo', 'rm', self._cfg_path], close_fds=True) else: os.remove(self._cfg_path) - try: - self._error_file.close() - except FileNotFoundError: - # The subprocess has deleted this file before we do - # so we ignore - pass diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index f108f0f82afbf..53f873ec1f7b0 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -34,7 +34,6 @@ class StandardTaskRunner(BaseTaskRunner): def __init__(self, local_task_job): super().__init__(local_task_job) self._rc = None - self.dag = local_task_job.task_instance.task.dag def start(self): if CAN_FORK and not self.run_as_user: @@ -42,9 +41,10 @@ def start(self): else: self.process = self._start_by_exec() - def _start_by_exec(self): + def _start_by_exec(self) -> psutil.Process: subprocess = self.run_command() - return psutil.Process(subprocess.pid) + self.process = psutil.Process(subprocess.pid) + return self.process def _start_by_fork(self): pid = os.fork() @@ -62,6 +62,7 @@ def _start_by_fork(self): from airflow import settings from airflow.cli.cli_parser import get_parser from airflow.sentry import Sentry + from airflow.utils.cli import get_dag # Force a new SQLAlchemy session. We can't share open DB handles # between process. 
The cli code will re-create this as part of its @@ -83,9 +84,13 @@ def _start_by_fork(self): if job_id is not None: proc_title += " {0.job_id}" setproctitle(proc_title.format(args)) + return_code = 0 try: - args.func(args, dag=self.dag) + # parse dag file since `airflow tasks run --local` does not parse dag file + dag = get_dag(args.subdir, args.dag_id) + args.func(args, dag=dag) + return_code = 0 except Exception as exc: return_code = 1 diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 496a41144415a..796610803e594 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -206,6 +206,16 @@ def get_dag(subdir: Optional[str], dag_id: str) -> "DAG": return dagbag.dags[dag_id] +def get_dag_by_deserialization(dag_id: str) -> "DAG": + from airflow.models.serialized_dag import SerializedDagModel + + dag_model = SerializedDagModel.get(dag_id) + if dag_model is None: + raise AirflowException(f"Serialized DAG: {dag_id} could not be found") + + return dag_model.dag + + def get_dags(subdir: Optional[str], dag_id: str, use_regex: bool = False): """Returns DAG(s) matching a given regex or dag_id""" from airflow.models import DagBag diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 481103f1f3077..fac04fd50802d 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -36,6 +36,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, DagRunNotFound from airflow.models import DagBag, DagRun, Pool, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import timezone from airflow.utils.dates import days_ago from airflow.utils.session import create_session @@ -119,6 +120,38 @@ def test_test_with_existing_dag_run(self): ] ) + @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization") + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") + def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization): + """ + Test using serialized dag for local task_run + """ + task_id = self.dag.task_ids[0] + args = [ + 'tasks', + 'run', + '--ignore-all-dependencies', + '--local', + self.dag_id, + task_id, + self.run_id, + ] + mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag + + task_command.task_run(self.parser.parse_args(args)) + mock_local_job.assert_called_once_with( + task_instance=mock.ANY, + mark_success=False, + ignore_all_deps=True, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pickle_id=None, + pool=None, + external_executor_id=None, + ) + mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id) + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") def test_run_with_existing_dag_run_id(self, mock_local_job): """ diff --git a/tests/conftest.py b/tests/conftest.py index 5cbffc06f065f..c8bc240474a5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -751,3 +751,21 @@ def session(): with create_session() as session: yield session session.rollback() + + +@pytest.fixture() +def get_test_dag(): + def _get(dag_id): + from airflow.models.dagbag import DagBag + from airflow.models.serialized_dag import SerializedDagModel + + dag_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dags', f'{dag_id}.py') + dagbag = DagBag(dag_folder=dag_file, include_examples=False) + + dag = dagbag.get_dag(dag_id) + dag.sync_to_db() + SerializedDagModel.write_dag(dag) + + return dag + + 
return _get diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index d7712feb8c6bd..9b8716bded720 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -366,7 +366,7 @@ def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag): dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) - @patch.object(TaskInstance, 'handle_failure_with_callback') + @patch.object(TaskInstance, 'handle_failure') def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) @@ -423,18 +423,14 @@ def test_failure_callbacks_should_not_drop_hostname(self): tis = session.query(TaskInstance) assert tis[0].hostname == "test_hostname" - def test_process_file_should_failure_callback(self, monkeypatch, tmp_path): + def test_process_file_should_failure_callback(self, monkeypatch, tmp_path, get_test_dag): callback_file = tmp_path.joinpath("callback.txt") callback_file.touch() monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) - dag_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py' - ) - dagbag = DagBag(dag_folder=dag_file, include_examples=False) dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag = dagbag.get_dag('test_om_failure_callback_dag') - task = dag.get_task(task_id='test_om_failure_callback_task') + dag = get_test_dag('test_on_failure_callback') + task = dag.get_task(task_id='test_on_failure_callback_task') with create_session() as session: dagrun = dag.create_dagrun( state=State.RUNNING, @@ -442,7 +438,7 @@ def test_process_file_should_failure_callback(self, monkeypatch, tmp_path): run_type=DagRunType.SCHEDULED, session=session, ) - (ti,) = dagrun.task_instances + ti = dagrun.get_task_instance(task.task_id) ti.refresh_from_task(task) requests = [ @@ -452,9 +448,11 @@ def test_process_file_should_failure_callback(self, monkeypatch, tmp_path): msg="Message", ) ] - dag_file_processor.process_file(dag_file, requests, session=session) + dag_file_processor.process_file(dag.fileloc, requests, session=session) - assert "Callback fired" == callback_file.read_text() + ti.refresh_from_db() + msg = ' '.join([str(k) for k in ti.key.primary]) + ' fired callback' + assert msg in callback_file.read_text() @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) def test_add_unparseable_file_before_sched_start_creates_import_error(self, tmpdir): diff --git a/tests/dags/test_dagrun_fast_follow.py b/tests/dags/test_dagrun_fast_follow.py new file mode 100644 index 0000000000000..b5f291ad12acc --- /dev/null +++ b/tests/dags/test_dagrun_fast_follow.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime + +from airflow.models import DAG +from airflow.operators.python import PythonOperator + +DEFAULT_DATE = datetime(2016, 1, 1) + +args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, +} + + +dag_id = 'test_dagrun_fast_follow' +dag = DAG(dag_id=dag_id, default_args=args) + +# A -> B -> C +task_a = PythonOperator(task_id='A', dag=dag, python_callable=lambda: True) +task_b = PythonOperator(task_id='B', dag=dag, python_callable=lambda: True) +task_c = PythonOperator(task_id='C', dag=dag, python_callable=lambda: True) +task_a.set_downstream(task_b) +task_b.set_downstream(task_c) + +# G -> F -> E & D -> E +task_d = PythonOperator(task_id='D', dag=dag, python_callable=lambda: True) +task_e = PythonOperator(task_id='E', dag=dag, python_callable=lambda: True) +task_f = PythonOperator(task_id='F', dag=dag, python_callable=lambda: True) +task_g = PythonOperator(task_id='G', dag=dag, python_callable=lambda: True) +task_g.set_downstream(task_f) +task_f.set_downstream(task_e) +task_d.set_downstream(task_e) + +# H -> J & I -> J +task_h = PythonOperator(task_id='H', dag=dag, python_callable=lambda: True) +task_i = PythonOperator(task_id='I', dag=dag, python_callable=lambda: True) +task_j = PythonOperator(task_id='J', dag=dag, python_callable=lambda: True) +task_h.set_downstream(task_j) +task_i.set_downstream(task_j) + +# wait_for_downstream test +# K -> L -> M +task_k = PythonOperator(task_id='K', dag=dag, python_callable=lambda: True, wait_for_downstream=True) +task_l = PythonOperator(task_id='L', dag=dag, python_callable=lambda: True, wait_for_downstream=True) +task_m = PythonOperator(task_id='M', dag=dag, python_callable=lambda: True, wait_for_downstream=True) +task_k.set_downstream(task_l) +task_l.set_downstream(task_m) diff --git a/tests/dags/test_mark_state.py b/tests/dags/test_mark_state.py new file mode 100644 index 0000000000000..e69c99fd81e27 --- /dev/null +++ b/tests/dags/test_mark_state.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from datetime import datetime +from time import sleep + +from airflow.models import DAG +from airflow.operators.python import PythonOperator +from airflow.utils.session import create_session +from airflow.utils.state import State + +DEFAULT_DATE = datetime(2016, 1, 1) + +args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, +} + + +dag_id = 'test_mark_state' +dag = DAG(dag_id=dag_id, default_args=args) + + +def success_callback(context): + assert context['dag_run'].dag_id == dag_id + + +def test_mark_success_no_kill(ti): + assert ti.state == State.RUNNING + # Simulate marking this successful in the UI + with create_session() as session: + ti.state = State.SUCCESS + session.merge(ti) + session.commit() + # The below code will not run as heartbeat will detect change of state + sleep(10) + + +PythonOperator( + task_id="test_mark_success_no_kill", + python_callable=test_mark_success_no_kill, + dag=dag, + on_success_callback=success_callback, +) + + +def check_failure(context): + assert context['dag_run'].dag_id == dag_id + assert context['exception'] == "task marked as failed externally" + + +def test_mark_failure_externally(ti): + assert State.RUNNING == ti.state + with create_session() as session: + ti.log.info("Marking TI as failed 'externally'") + ti.state = State.FAILED + session.merge(ti) + session.commit() + + # This should not happen -- the state change should be noticed and the task should get killed + sleep(10) + assert False + + +PythonOperator( + task_id='test_mark_failure_externally', + python_callable=test_mark_failure_externally, + on_failure_callback=check_failure, + dag=dag, +) + + +def test_mark_skipped_externally(ti): + assert State.RUNNING == ti.state + sleep(0.1) # for timeout + with create_session() as session: + ti.log.info("Marking TI as failed 'externally'") + ti.state = State.SKIPPED + session.merge(ti) + session.commit() + + # This should not happen -- the state change should be noticed and the task should get killed + sleep(10) + assert False + + +PythonOperator(task_id='test_mark_skipped_externally', python_callable=test_mark_skipped_externally, dag=dag) + +PythonOperator(task_id='dummy', python_callable=lambda: True, dag=dag) diff --git a/tests/dags/test_mark_success.py b/tests/dags/test_mark_success.py deleted file mode 100644 index d5c05d752c27a..0000000000000 --- a/tests/dags/test_mark_success.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from datetime import datetime -from time import sleep - -from airflow.models import DAG -from airflow.operators.python import PythonOperator - -DEFAULT_DATE = datetime(2016, 1, 1) - -args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, -} - - -dag = DAG(dag_id='test_mark_success', default_args=args) -task = PythonOperator(task_id='task1', python_callable=lambda x: sleep(x), op_args=[600], dag=dag) diff --git a/tests/dags/test_on_failure_callback.py b/tests/dags/test_on_failure_callback.py index 1c4d996636c54..a7f49e44c32ac 100644 --- a/tests/dags/test_on_failure_callback.py +++ b/tests/dags/test_on_failure_callback.py @@ -19,7 +19,9 @@ from datetime import datetime from airflow import DAG -from airflow.operators.empty import EmptyOperator +from airflow.exceptions import AirflowFailException +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator DEFAULT_DATE = datetime(2016, 1, 1) @@ -28,14 +30,29 @@ 'start_date': DEFAULT_DATE, } -dag = DAG(dag_id='test_om_failure_callback_dag', default_args=args) +dag = DAG(dag_id='test_on_failure_callback', default_args=args) -def write_data_to_callback(*arg, **kwargs): - with open(os.environ.get('AIRFLOW_CALLBACK_FILE'), "w+") as f: - f.write("Callback fired") +def write_data_to_callback(context): + msg = ' '.join([str(k) for k in context['ti'].key.primary]) + f' fired callback with pid: {os.getpid()}' + with open(os.environ.get('AIRFLOW_CALLBACK_FILE'), "a+") as f: + f.write(msg) -task = EmptyOperator( - task_id='test_om_failure_callback_task', dag=dag, on_failure_callback=write_data_to_callback +def task_function(ti): + raise AirflowFailException() + + +PythonOperator( + task_id='test_on_failure_callback_task', + on_failure_callback=write_data_to_callback, + python_callable=task_function, + dag=dag, +) + +BashOperator( + task_id='bash_sleep', + on_failure_callback=write_data_to_callback, + bash_command='touch $AIRFLOW_CALLBACK_FILE; sleep 10', + dag=dag, ) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 5ad82291a1988..c358ada979e5b 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -39,6 +39,7 @@ from airflow.jobs.backfill_job import BackfillJob from airflow.models import DagBag, Pool, TaskInstance as TI from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstanceKey from airflow.models.taskmap import TaskMap from airflow.operators.empty import EmptyOperator @@ -81,6 +82,9 @@ def set_instance_attrs(self, dag_bag): self.clean_db() self.parser = cli_parser.get_parser() self.dagbag = dag_bag + # `airflow tasks run` relies on serialized_dag + for dag in self.dagbag.dags.values(): + SerializedDagModel.write_dag(dag) def _get_dummy_dag( self, diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index bad4c3c91e5dc..298e5b9a6392a 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -19,9 +19,10 @@ import datetime import os import signal +import threading import time import uuid -from multiprocessing import Lock, Value +from multiprocessing import Value from typing import List, Union from unittest import mock from unittest.mock import patch @@ -30,21 +31,21 @@ import pytest from airflow import settings -from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException +from airflow.exceptions import AirflowException from 
airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.local_task_job import LocalTaskJob from airflow.jobs.scheduler_job import SchedulerJob from airflow.models.dagbag import DagBag +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator -from airflow.operators.python import BranchPythonOperator, PythonOperator +from airflow.operators.python import PythonOperator from airflow.task.task_runner.standard_task_runner import StandardTaskRunner from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timeout import timeout -from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from tests.test_utils import db from tests.test_utils.asserts import assert_queries_count @@ -279,39 +280,25 @@ def test_heartbeat_failed_fast(self): delta = (time2 - time1).total_seconds() assert abs(delta - job.heartrate) < 0.5 - @patch('airflow.utils.process_utils.subprocess.check_call') - @patch.object(StandardTaskRunner, 'return_code') - def test_mark_success_no_kill(self, mock_return_code, _check_call, caplog, dag_maker): + def test_mark_success_no_kill(self, caplog, get_test_dag, session): """ Test that ensures that mark_success in the UI doesn't cause the task to fail, and that the task exits """ - session = settings.Session() - - def task_function(ti): - assert ti.state == State.RUNNING - # Simulate marking this successful in the UI - ti.state = State.SUCCESS - session.merge(ti) - session.commit() - # The below code will not run as heartbeat will detect change of state - time.sleep(10) - - with dag_maker('test_mark_success'): - task = PythonOperator(task_id="task1", python_callable=task_function) - dr = dag_maker.create_dagrun() + dag = get_test_dag('test_mark_state') + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + task = dag.get_task(task_id='test_mark_success_no_kill') - ti = dr.task_instances[0] + ti = dr.get_task_instance(task.task_id) ti.refresh_from_task(task) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) - def dummy_return_code(*args, **kwargs): - return None if not job1.terminating else -9 - - # The return code when we mark success in the UI is None - mock_return_code.side_effect = dummy_return_code - with timeout(30): job1.run() ti.refresh_from_db() @@ -389,42 +376,22 @@ def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create assert time_end - time_start < job1.heartrate assert "Task exited with return code 0" in caplog.text - def test_mark_failure_on_failure_callback(self, caplog, dag_maker): + def test_mark_failure_on_failure_callback(self, caplog, get_test_dag): """ Test that ensures that mark_failure in the UI fails the task, and executes on_failure_callback """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. 
- failure_callback_called = Value('i', 0) - session = settings.Session() - - def check_failure(context): - with failure_callback_called.get_lock(): - failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_mark_failure' - assert context['exception'] == "task marked as failed externally" - - def task_function(ti): - assert State.RUNNING == ti.state - ti.log.info("Marking TI as failed 'externally'") - ti.state = State.FAILED - session.merge(ti) - session.commit() - - # This should not happen -- the state change should be noticed and the task should get killed - time.sleep(10) - assert False - - with dag_maker("test_mark_failure", start_date=DEFAULT_DATE): - task = PythonOperator( - task_id='test_state_succeeded1', - python_callable=task_function, - on_failure_callback=check_failure, + dag = get_test_dag('test_mark_state') + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, ) - dag_maker.create_dagrun() - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() + task = dag.get_task(task_id='test_mark_failure_externally') + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) with timeout(30): @@ -434,40 +401,27 @@ def task_function(ti): ti.refresh_from_db() assert ti.state == State.FAILED - assert failure_callback_called.value == 1 assert "State of this instance has been externally set to failed. " "Terminating instance." in caplog.text - def test_dagrun_timeout_logged_in_task_logs(self, caplog, dag_maker): + def test_dagrun_timeout_logged_in_task_logs(self, caplog, get_test_dag): """ Test that ensures that if a running task is externally skipped (due to a dagrun timeout) It is logged in the task logs. 
""" - - session = settings.Session() - - def task_function(ti): - assert State.RUNNING == ti.state - time.sleep(0.1) - ti.log.info("Marking TI as skipped externally") - ti.state = State.SKIPPED - session.merge(ti) - session.commit() - - # This should not happen -- the state change should be noticed and the task should get killed - time.sleep(10) - assert False - - with dag_maker( - "test_mark_failure", start_date=DEFAULT_DATE, dagrun_timeout=datetime.timedelta(microseconds=1) - ): - task = PythonOperator( - task_id='skipped_externally', - python_callable=task_function, + dag = get_test_dag('test_mark_state') + dag.dagrun_timeout = datetime.timedelta(microseconds=1) + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + start_date=DEFAULT_DATE, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, ) - dag_maker.create_dagrun() - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() + task = dag.get_task(task_id='test_mark_skipped_externally') + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) with timeout(30): @@ -479,476 +433,285 @@ def task_function(ti): assert ti.state == State.SKIPPED assert "DagRun timed out after " in caplog.text - @patch('airflow.utils.process_utils.subprocess.check_call') - @patch.object(StandardTaskRunner, 'return_code') - def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker): + def test_failure_callback_called_by_airflow_run_raw_process(self, monkeypatch, tmp_path, get_test_dag): """ - Test that ensures that when a task exits with failure by itself, - failure callback is only called once + Ensure failure callback of a task is run by the airflow run --raw process """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. - failure_callback_called = Value('i', 0) - callback_count_lock = Lock() - - def failure_callback(context): - with callback_count_lock: - failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_failure_callback_race' - assert isinstance(context['exception'], AirflowFailException) - - def task_function(ti): - raise AirflowFailException() - - with dag_maker("test_failure_callback_race"): - task = PythonOperator( - task_id='test_exit_on_failure', - python_callable=task_function, - on_failure_callback=failure_callback, + callback_file = tmp_path.joinpath("callback.txt") + callback_file.touch() + monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) + dag = get_test_dag('test_on_failure_callback') + with create_session() as session: + dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, ) - dag_maker.create_dagrun() + task = dag.get_task(task_id='test_on_failure_callback_task') ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - - # Simulate race condition where job1 heartbeat ran right after task - # state got set to failed by ti.handle_failure but before task process - # fully exits. See _execute loop in airflow/jobs/local_task_job.py. 
- # In this case, we have: - # * task_runner.return_code() is None - # * ti.state == State.Failed - # - # We also need to set return_code to a valid int after job1.terminating - # is set to True so _execute loop won't loop forever. - def dummy_return_code(*args, **kwargs): - return None if not job1.terminating else -9 - - mock_return_code.side_effect = dummy_return_code - - with timeout(10): - # This should be _much_ shorter to run. - # If you change this limit, make the timeout in the callable above bigger - job1.run() + job1.run() ti.refresh_from_db() assert ti.state == State.FAILED # task exits with failure state - assert failure_callback_called.value == 1 - - @patch('airflow.utils.process_utils.subprocess.check_call') - @patch.object(StandardTaskRunner, 'return_code') - def test_mark_success_on_success_callback(self, mock_return_code, _check_call, caplog, dag_maker): + with open(callback_file) as f: + lines = f.readlines() + assert len(lines) == 1 # invoke once + assert lines[0].startswith(ti.key.primary) + this_pid = str(os.getpid()) + assert this_pid not in lines[0] + + def test_mark_success_on_success_callback(self, caplog, get_test_dag): """ Test that ensures that where a task is marked success in the UI on_success_callback gets executed """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. - success_callback_called = Value('i', 0) - session = settings.Session() - - def success_callback(context): - with success_callback_called.get_lock(): - success_callback_called.value = 1 - assert context['dag_run'].dag_id == 'test_mark_success' - - def task_function(ti): - assert ti.state == State.RUNNING - # mark it success in the UI - ti.state = State.SUCCESS - session.merge(ti) - session.commit() - # This should not happen -- the state change should be noticed and the task should get killed - time.sleep(10) - assert False - - with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}): - task = PythonOperator( - task_id='test_state_succeeded1', - python_callable=task_function, - on_success_callback=success_callback, + dag = get_test_dag('test_mark_state') + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, ) - dag_maker.create_dagrun() - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + task = dag.get_task(task_id='test_mark_success_no_kill') - def dummy_return_code(*args, **kwargs): - return None if not job1.terminating else -9 + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) - # The return code when we mark success in the UI is None - mock_return_code.side_effect = dummy_return_code + job = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - settings.engine.dispose() with timeout(30): - job1.run() # This should run fast because of the return_code=None + job.run() # This should run fast because of the return_code=None ti.refresh_from_db() - assert success_callback_called.value == 1 assert "State of this instance has been externally set to success. " "Terminating instance." 
in caplog.text - @patch('airflow.utils.process_utils.subprocess.check_call') - def test_task_sigkill_calls_on_failure_callback(self, _check_call, caplog, dag_maker): - """ - Test that ensures that when a task is killed with sigkill - on_failure_callback gets executed + @pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL]) + def test_process_os_signal_calls_on_failure_callback( + self, monkeypatch, tmp_path, get_test_dag, signal_type + ): """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. - failure_callback_called = Value('i', 0) - - def failure_callback(context): - with failure_callback_called.get_lock(): - failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_send_sigkill' + Test that ensures that when a task is killed with sigkill or sigterm + on_failure_callback does not get executed by LocalTaskJob. - def task_function(ti): - assert ti.state == State.RUNNING - os.kill(os.getpid(), signal.SIGKILL) - - with dag_maker(dag_id='test_send_sigkill'): - task = PythonOperator( - task_id='test_on_failure', - python_callable=task_function, - on_failure_callback=failure_callback, + Callbacks should not be executed by LocalTaskJob. If the task killed via sigkill, + it will be reaped as zombie, then the callback is executed + """ + callback_file = tmp_path.joinpath("callback.txt") + # callback_file will be created by the task: bash_sleep + monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) + dag = get_test_dag('test_on_failure_callback') + with create_session() as session: + dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, ) - dag_maker.create_dagrun() - + task = dag.get_task(task_id='bash_sleep') ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - settings.engine.dispose() - with timeout(10): - job1.run() # This should be fast because the signal is sent right away - ti.refresh_from_db() - assert failure_callback_called.value == 1 - assert "Task exited with return code Negsignal.SIGKILL" in caplog.text - @pytest.mark.quarantined - def test_process_sigterm_calls_on_failure_callback(self, caplog, dag_maker): - """ - Test that ensures that when a task runner is killed with sigterm - on_failure_callback gets executed - """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. 
- failure_callback_called = Value('i', 0) - - def failure_callback(context): - with failure_callback_called.get_lock(): - failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_mark_failure' - - def task_function(ti): - assert ti.state == State.RUNNING - os.kill(psutil.Process(os.getpid()).ppid(), signal.SIGTERM) - - with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}): - task = PythonOperator( - task_id='test_on_failure', - python_callable=task_function, - on_failure_callback=failure_callback, - ) - dag_maker.create_dagrun() + signal_sent_status = {'sent': False} + + def get_ti_current_pid(ti) -> str: + with create_session() as session: + pid = ( + session.query(TaskInstance.pid) + .filter( + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + ) + .one_or_none() + ) + return pid[0] + + def send_signal(ti, signal_sent, sig): + while True: + task_pid = get_ti_current_pid( + ti + ) # get pid from the db, which is the pid of airflow run --raw + if ( + task_pid and ti.current_state() == State.RUNNING and os.path.isfile(callback_file) + ): # ensure task is running before sending sig + signal_sent['sent'] = True + os.kill(task_pid, sig) + break + time.sleep(1) + + thread = threading.Thread( + name="signaler", + target=send_signal, + args=(ti, signal_sent_status, signal_type), + ) + thread.daemon = True + thread.start() - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - settings.engine.dispose() - with timeout(10): - job1.run() + job1.run() + ti.refresh_from_db() - assert failure_callback_called.value == 1 - assert "Received SIGTERM. Terminating subprocesses" in caplog.text - assert "Task exited with return code 143" in caplog.text + + assert signal_sent_status['sent'] + + if signal_type == signal.SIGTERM: + assert ti.state == State.FAILED + with open(callback_file) as f: + lines = f.readlines() + + assert len(lines) == 1 + assert lines[0].startswith(ti.key.primary) + + this_pid = str(os.getpid()) + assert this_pid not in lines[0] # ensures callback is NOT run by LocalTaskJob + assert ( + str(ti.pid) in lines[0] + ) # ensures callback is run by airflow run --raw (TaskInstance#_run_raw_task) + elif signal_type == signal.SIGKILL: + assert ( + ti.state == State.RUNNING + ) # task exits with running state, will be reaped as zombie by scheduler + with open(callback_file) as f: + lines = f.readlines() + assert len(lines) == 0 @pytest.mark.parametrize( - "conf, dependencies, init_state, first_run_state, second_run_state, error_message", + "conf, init_state, first_run_state, second_run_state, task_ids_to_run, error_message", [ ( {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'B', 'B': 'C'}, {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, + ['A', 'B'], "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. 
Same for B and C.", ), ( {('scheduler', 'schedule_after_task_execution'): 'False'}, - {'A': 'B', 'B': 'C'}, {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, None, + ['A', 'B'], "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.", ), ( {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'B', 'C': 'B', 'D': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, - {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + {'D': State.QUEUED, 'E': State.NONE, 'F': State.NONE, 'G': State.NONE}, + {'D': State.SUCCESS, 'E': State.NONE, 'F': State.NONE, 'G': State.NONE}, None, - "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.", + ['D', 'E'], + "G -> F -> E & D -> E, when D runs but F isn't QUEUED yet, E shouldn't be QUEUED.", ), ( {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'C', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, + {'H': State.QUEUED, 'I': State.FAILED, 'J': State.NONE}, + {'H': State.SUCCESS, 'I': State.FAILED, 'J': State.UPSTREAM_FAILED}, None, - "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.", + ['H', 'I'], + "H -> J & I -> J, when H is QUEUED but I has FAILED, J is marked UPSTREAM_FAILED.", ), ], ) def test_fast_follow( - self, conf, dependencies, init_state, first_run_state, second_run_state, error_message, dag_maker + self, + conf, + init_state, + first_run_state, + second_run_state, + task_ids_to_run, + error_message, + get_test_dag, ): - with conf_vars(conf): - session = settings.Session() - - python_callable = lambda: True - with dag_maker('test_dagrun_fast_follow') as dag: - task_a = PythonOperator(task_id='A', python_callable=python_callable) - task_b = PythonOperator(task_id='B', python_callable=python_callable) - task_c = PythonOperator(task_id='C', python_callable=python_callable) - if 'D' in init_state: - task_d = PythonOperator(task_id='D', python_callable=python_callable) - for upstream, downstream in dependencies.items(): - dag.set_dependency(upstream, downstream) + + dag = get_test_dag( + 'test_dagrun_fast_follow', + ) scheduler_job = SchedulerJob(subdir=os.devnull) scheduler_job.dagbag.bag_dag(dag, root_dag=dag) dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING) - task_instance_a = TaskInstance(task_a, run_id=dag_run.run_id, state=init_state['A']) - - task_instance_b = TaskInstance(task_b, run_id=dag_run.run_id, state=init_state['B']) - - task_instance_c = TaskInstance(task_c, run_id=dag_run.run_id, state=init_state['C']) - - if 'D' in init_state: - task_instance_d = TaskInstance(task_d, run_id=dag_run.run_id, state=init_state['D']) - session.merge(task_instance_d) - - session.merge(task_instance_a) - session.merge(task_instance_b) - session.merge(task_instance_c) - session.flush() + ti_by_task_id = {} + with create_session() as session: + for task_id in init_state: + ti = TaskInstance(dag.get_task(task_id), run_id=dag_run.run_id, state=init_state[task_id]) + session.merge(ti) + ti_by_task_id[task_id] = ti + ti = TaskInstance(task=dag.get_task(task_ids_to_run[0]), execution_date=dag_run.execution_date) + ti.refresh_from_db() job1 = LocalTaskJob( - task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor() + task_instance=ti, + ignore_ti_state=True, + executor=SequentialExecutor(), ) 
job1.task_runner = StandardTaskRunner(job1) - job2 = LocalTaskJob( - task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor() - ) - job2.task_runner = StandardTaskRunner(job2) - - settings.engine.dispose() job1.run() self.validate_ti_states(dag_run, first_run_state, error_message) if second_run_state: + ti = TaskInstance( + task=dag.get_task(task_ids_to_run[1]), execution_date=dag_run.execution_date + ) + ti.refresh_from_db() + job2 = LocalTaskJob( + task_instance=ti, + ignore_ti_state=True, + executor=SequentialExecutor(), + ) + job2.task_runner = StandardTaskRunner(job2) job2.run() self.validate_ti_states(dag_run, second_run_state, error_message) if scheduler_job.processor_agent: scheduler_job.processor_agent.end() @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'}) - def test_mini_scheduler_works_with_wait_for_downstream(self, caplog, dag_maker): - session = settings.Session() - with dag_maker(default_args={'wait_for_downstream': True}, catchup=False) as dag: - task_a = PythonOperator(task_id='A', python_callable=lambda: True) - task_b = PythonOperator(task_id='B', python_callable=lambda: True) - task_c = PythonOperator(task_id='C', python_callable=lambda: True) - task_a >> task_b >> task_c - - scheduler_job = SchedulerJob(subdir=os.devnull) - scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag): + dag = get_test_dag('test_dagrun_fast_follow') + dag.catchup = False + SerializedDagModel.write_dag(dag) dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE) dr2 = dag.create_dagrun( run_id='test_2', state=State.RUNNING, execution_date=DEFAULT_DATE + datetime.timedelta(hours=1) ) - ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.SUCCESS) - ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.SUCCESS) - ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.RUNNING) - ti2_a = TaskInstance(task_a, run_id=dr2.run_id, state=State.NONE) - ti2_b = TaskInstance(task_b, run_id=dr2.run_id, state=State.NONE) - ti2_c = TaskInstance(task_c, run_id=dr2.run_id, state=State.NONE) - session.merge(ti_a) - session.merge(ti_b) - session.merge(ti_c) - session.merge(ti2_a) - session.merge(ti2_b) - session.merge(ti2_c) - session.flush() - - job1 = LocalTaskJob(task_instance=ti2_a, ignore_ti_state=True, executor=SequentialExecutor()) - job1.task_runner = StandardTaskRunner(job1) - t = time.time() - job1.run() - d = time.time() - t - - ti2_a.refresh_from_db(session) - ti2_b.refresh_from_db(session) - assert ti2_a.state == State.SUCCESS - assert ti2_b.state == State.NONE - assert ( - "0 downstream tasks scheduled from follow-on schedule" in caplog.text - ), f"Failed after {d.total_seconds()}: {caplog.text}" - - failed_deps = list(ti2_b.get_failed_dep_statuses(session=session)) - assert len(failed_deps) == 1 - assert failed_deps[0].dep_name == "Previous Dagrun State" - assert not failed_deps[0].passed - - @pytest.mark.parametrize( - "exception, trigger_rule", - [ - (AirflowFailException(), TriggerRule.ALL_DONE), - (AirflowFailException(), TriggerRule.ALL_FAILED), - (AirflowSkipException(), TriggerRule.ALL_DONE), - (AirflowSkipException(), TriggerRule.ALL_SKIPPED), - (AirflowSkipException(), TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS), - ], - ) - @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'}) - def test_mini_scheduler_works_with_skipped_and_failed( - self, exception, trigger_rule, caplog, session, dag_maker - ): - """ - In these 
cases D is running, at no decision can be made about C. - """ - - def raise_(): - raise exception - - with dag_maker(catchup=False) as dag: - task_a = PythonOperator(task_id='A', python_callable=raise_) - task_b = PythonOperator(task_id='B', python_callable=lambda: True) - task_c = PythonOperator(task_id='C', python_callable=lambda: True, trigger_rule=trigger_rule) - task_d = PythonOperator(task_id='D', python_callable=lambda: True) - task_a >> task_b >> task_c - task_d >> task_c - - dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE) - ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED) - ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE) - ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE) - ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING) - - session.merge(ti_a) - session.merge(ti_b) - session.merge(ti_c) - session.merge(ti_d) - session.flush() - - job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor()) - job1.task_runner = StandardTaskRunner(job1) - job1.run() + task_k = dag.get_task('K') + task_l = dag.get_task('L') + with create_session() as session: + ti_k = TaskInstance(task_k, run_id=dr.run_id, state=State.SUCCESS) + ti_b = TaskInstance(task_l, run_id=dr.run_id, state=State.SUCCESS) - ti_b.refresh_from_db(session) - ti_c.refresh_from_db(session) - assert ti_b.state in (State.SKIPPED, State.UPSTREAM_FAILED) - assert ti_c.state == State.NONE - assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text + ti2_k = TaskInstance(task_k, run_id=dr2.run_id, state=State.NONE) + ti2_l = TaskInstance(task_l, run_id=dr2.run_id, state=State.NONE) - failed_deps = list(ti_c.get_failed_dep_statuses(session=session)) - assert len(failed_deps) == 1 - assert failed_deps[0].dep_name == "Trigger Rule" - assert not failed_deps[0].passed + session.merge(ti_k) + session.merge(ti_b) - @pytest.mark.parametrize( - "trigger_rule", - [ - TriggerRule.ONE_SUCCESS, - TriggerRule.ALL_SKIPPED, - TriggerRule.NONE_FAILED, - TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, - ], - ) - @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'}) - def test_mini_scheduler_works_with_branch_python_operator(self, trigger_rule, caplog, session, dag_maker): - """ - In these cases D is running, at no decision can be made about C. 
- """ - with dag_maker(catchup=False) as dag: - task_a = BranchPythonOperator(task_id='A', python_callable=lambda: []) - task_b = PythonOperator(task_id='B', python_callable=lambda: True) - task_c = PythonOperator( - task_id='C', - python_callable=lambda: True, - trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, - ) - task_d = PythonOperator(task_id='D', python_callable=lambda: True) - task_a >> task_b >> task_c - task_d >> task_c + session.merge(ti2_k) + session.merge(ti2_l) - dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE) - ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED) - ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE) - ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE) - ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING) - - session.merge(ti_a) - session.merge(ti_b) - session.merge(ti_c) - session.merge(ti_d) - session.flush() - - job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti2_k, ignore_ti_state=True, executor=SequentialExecutor()) job1.task_runner = StandardTaskRunner(job1) job1.run() - ti_b.refresh_from_db(session) - ti_c.refresh_from_db(session) - assert ti_b.state == State.SKIPPED - assert ti_c.state == State.NONE + ti2_k.refresh_from_db() + ti2_l.refresh_from_db() + assert ti2_k.state == State.SUCCESS + assert ti2_l.state == State.NONE assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text - failed_deps = list(ti_c.get_failed_dep_statuses(session=session)) + failed_deps = list(ti2_l.get_failed_dep_statuses()) assert len(failed_deps) == 1 - assert failed_deps[0].dep_name == "Trigger Rule" + assert failed_deps[0].dep_name == "Previous Dagrun State" assert not failed_deps[0].passed - @patch('airflow.utils.process_utils.subprocess.check_call') - def test_task_sigkill_works_with_retries(self, _check_call, caplog, dag_maker): - """ - Test that ensures that tasks are retried when they receive sigkill - """ - # use shared memory value so we can properly track value change even if - # it's been updated across processes. - retry_callback_called = Value('i', 0) - - def retry_callback(context): - with retry_callback_called.get_lock(): - retry_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_mark_failure_2' - - def task_function(ti): - os.kill(os.getpid(), signal.SIGKILL) - - with dag_maker( - dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} - ): - task = PythonOperator( - task_id='test_on_failure', - python_callable=task_function, - retries=1, - on_retry_callback=retry_callback, - ) - dr = dag_maker.create_dagrun() - ti = dr.task_instances[0] - ti.refresh_from_task(task) - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - settings.engine.dispose() - with timeout(10): - job1.run() - assert retry_callback_called.value == 1 - assert "Task exited with return code Negsignal.SIGKILL" in caplog.text - @pytest.mark.quarantined def test_process_sigterm_works_with_retries(self, caplog, dag_maker): """ @@ -987,34 +750,6 @@ def task_function(ti): assert "Received SIGTERM. 
Terminating subprocesses" in caplog.text assert "Task exited with return code 143" in caplog.text - def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker): - """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" - schedule_interval = datetime.timedelta(days=1) - with dag_maker(dag_id='test_dags', schedule_interval=schedule_interval) as dag: - op1 = PythonOperator(task_id='dummy', python_callable=lambda: True) - - session = settings.Session() - dagmodel = dag_maker.dag_model - dagmodel.next_dagrun_create_after = DEFAULT_DATE + schedule_interval - dagmodel.is_paused = True - session.merge(dagmodel) - session.flush() - # Write Dag to DB - dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False) - dagbag.bag_dag(dag, root_dag=dag) - dagbag.sync_to_db() - - dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) - - assert dr.state == State.RUNNING - ti = TaskInstance(op1, dr.execution_date) - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - job1.task_runner = StandardTaskRunner(job1) - job1.run() - session.add(dr) - session.refresh(dr) - assert dr.state == State.SUCCESS - @pytest.fixture() def clean_db_helper(): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e53b52e11b1c3..4fe5cb75f2662 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -23,7 +23,6 @@ import signal import sys import urllib -from tempfile import NamedTemporaryFile from traceback import format_exception from typing import List, Optional, Union, cast from unittest import mock @@ -56,8 +55,9 @@ Variable, XCom, ) +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskfail import TaskFail -from airflow.models.taskinstance import TaskInstance, load_error_file, set_error_file +from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCOM_RETURN_KEY from airflow.operators.bash import BashOperator @@ -80,7 +80,6 @@ from airflow.version import version from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER from tests.test_utils import db -from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_connections, clear_db_runs @@ -111,15 +110,7 @@ def wrap_task_instance(self, ti): def success_handler(self, context): self.callback_ran = True - session = settings.Session() - temp_instance = ( - session.query(TI) - .filter(TI.task_id == self.task_id) - .filter(TI.dag_id == self.dag_id) - .filter(TI.execution_date == self.execution_date) - .one() - ) - self.task_state_in_callback = temp_instance.state + self.task_state_in_callback = context['ti'].state class TestTaskInstance: @@ -142,17 +133,6 @@ def setup_method(self): def teardown_method(self): self.clean_db() - def test_load_error_file_returns_None_for_closed_file(self): - error_fd = NamedTemporaryFile() - error_fd.close() - assert load_error_file(error_fd) is None - - def test_load_error_file_loads_correctly(self): - error_message = "some random error message" - with NamedTemporaryFile() as error_fd: - set_error_file(error_fd.name, error=error_message) - assert load_error_file(error_fd) == error_message - def test_set_task_dates(self, dag_maker): """ Test that tasks properly take start/end dates from DAGs @@ -1074,7 +1054,7 @@ def test_xcom_pull_mapped(self, dag_maker, 
session): ti_1_0 = dagrun.get_task_instance("task_1", session=session) ti_1_0.map_index = 0 - ti_1_1 = session.merge(TaskInstance(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) + ti_1_1 = session.merge(TI(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) session.flush() ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session) @@ -1223,18 +1203,56 @@ def post_execute(self, context, result=None): def test_check_and_change_state_before_execution(self, create_task_instance): ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') - assert ti._try_number == 0 - assert ti.check_and_change_state_before_execution() + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + + assert ti_from_deserialized_task._try_number == 0 + assert ti_from_deserialized_task.check_and_change_state_before_execution() # State should be running, and try_number column should be incremented - assert ti.state == State.RUNNING - assert ti._try_number == 1 + assert ti_from_deserialized_task.state == State.RUNNING + assert ti_from_deserialized_task._try_number == 1 def test_check_and_change_state_before_execution_dep_not_met(self, create_task_instance): ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') task2 = EmptyOperator(task_id='task2', dag=ti.task.dag, start_date=DEFAULT_DATE) ti.task >> task2 - ti = TI(task=task2, run_id=ti.run_id) - assert not ti.check_and_change_state_before_execution() + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti2 = TI(task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id) + assert not ti2.check_and_change_state_before_execution() + + def test_check_and_change_state_before_execution_dep_not_met_already_running(self, create_task_instance): + """return False if the task instance state is running""" + ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') + with create_session() as _: + ti.state = State.RUNNING + + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + + assert not ti_from_deserialized_task.check_and_change_state_before_execution() + assert ti_from_deserialized_task.state == State.RUNNING + + def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( + self, create_task_instance + ): + """return False if the task instance state is failed""" + ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') + with create_session() as _: + ti.state = State.FAILED + + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + + assert not ti_from_deserialized_task.check_and_change_state_before_execution() + assert ti_from_deserialized_task.state == State.FAILED def test_try_number(self, create_task_instance): """ @@ -1453,7 +1471,6 @@ def test_success_callback_no_race_condition(self, create_task_instance): callback_wrapper.wrap_task_instance(ti) ti._run_raw_task() - ti._run_finished_callback() assert callback_wrapper.callback_ran assert callback_wrapper.task_state_in_callback == State.SUCCESS ti.refresh_from_db() @@ 
-1796,15 +1813,15 @@ def on_execute_callable(context): assert ti.state == State.SUCCESS @pytest.mark.parametrize( - "finished_state, expected_message", + "finished_state, callback_type", [ - (State.SUCCESS, "Error when executing on_success_callback"), - (State.UP_FOR_RETRY, "Error when executing on_retry_callback"), - (State.FAILED, "Error when executing on_failure_callback"), + (State.SUCCESS, "on_success"), + (State.UP_FOR_RETRY, "on_retry"), + (State.FAILED, "on_failure"), ], ) def test_finished_callbacks_handle_and_log_exception( - self, finished_state, expected_message, create_task_instance + self, finished_state, callback_type, create_task_instance ): called = completed = False @@ -1822,10 +1839,11 @@ def on_finish_callable(context): state=finished_state, ) ti._log = mock.Mock() - ti._run_finished_callback() + ti._run_finished_callback(on_finish_callable, {}, callback_type) assert called assert not completed + expected_message = f"Error when executing {callback_type} callback" ti.log.exception.assert_called_once_with(expected_message) @provide_session @@ -1857,7 +1875,6 @@ def test_handle_failure(self, create_dummy_dag, session=None): ti1.task = task1 ti1.state = State.FAILED ti1.handle_failure("test failure handling") - ti1._run_finished_callback() context_arg_1 = mock_on_failure_1.call_args[0][0] assert context_arg_1 and "task_instance" in context_arg_1 @@ -1877,7 +1894,6 @@ def test_handle_failure(self, create_dummy_dag, session=None): session.add(ti2) session.flush() ti2.handle_failure("test retry handling") - ti2._run_finished_callback() mock_on_failure_2.assert_not_called() @@ -1899,7 +1915,6 @@ def test_handle_failure(self, create_dummy_dag, session=None): session.flush() ti3.state = State.FAILED ti3.handle_failure("test force_fail handling", force_fail=True) - ti3._run_finished_callback() context_arg_3 = mock_on_failure_3.call_args[0][0] assert context_arg_3 and "task_instance" in context_arg_3 @@ -2312,41 +2327,6 @@ def setup_method(self) -> None: def teardown_method(self) -> None: self._clean() - @pytest.mark.parametrize("expected_query_count, mark_success", [(12, False), (5, True)]) - @provide_session - def test_execute_queries_count( - self, expected_query_count, mark_success, create_task_instance, session=None - ): - ti = create_task_instance(session=session, state=State.RUNNING) - assert ti.dag_run - - # an extra query is fired in RenderedTaskInstanceFields.delete_old_records - # for other DBs. 
delete_old_records is called only when mark_success is False - expected_query_count_based_on_db = ( - expected_query_count + 1 - if session.bind.dialect.name == "mssql" and expected_query_count > 0 and not mark_success - else expected_query_count - ) - - session.flush() - - with assert_queries_count(expected_query_count_based_on_db): - ti._run_raw_task(mark_success=mark_success, session=session) - - @provide_session - def test_execute_queries_count_store_serialized(self, create_task_instance, session=None): - ti = create_task_instance(session=session, state=State.RUNNING) - assert ti.dag_run - - # an extra query is fired in RenderedTaskInstanceFields.delete_old_records - # for other DBs - expected_query_count_based_on_db = 5 - - session.flush() - - with assert_queries_count(expected_query_count_based_on_db): - ti._run_raw_task(session) - @pytest.mark.parametrize("mode", ["poke", "reschedule"]) @pytest.mark.parametrize("retries", [0, 1]) diff --git a/tests/task/task_runner/test_base_task_runner.py b/tests/task/task_runner/test_base_task_runner.py index a257a8e883136..b6d29e55b4b28 100644 --- a/tests/task/task_runner/test_base_task_runner.py +++ b/tests/task/task_runner/test_base_task_runner.py @@ -47,7 +47,7 @@ def test_config_copy_mode(tmp_configuration_copy, subprocess_call, dag_maker, im if impersonation: subprocess_call.assert_called_with( - ['sudo', 'chown', impersonation, "/tmp/some-string", runner._error_file.name], close_fds=True + ['sudo', 'chown', impersonation, "/tmp/some-string"], close_fds=True ) else: subprocess_call.not_assert_called() diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index e7a51d94ca36e..2ed266a54321c 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -17,10 +17,8 @@ # under the License. import logging import os -import re import time from logging.config import dictConfig -from tempfile import NamedTemporaryFile from unittest import mock import psutil @@ -201,7 +199,7 @@ def test_on_kill(self): dag = dagbag.dags.get('test_on_kill') task = dag.get_task('task1') - with create_session() as session, NamedTemporaryFile("w", delete=False) as f: + with create_session() as session: dag.create_dagrun( run_id="test", data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -215,9 +213,6 @@ def test_on_kill(self): ti.refresh_from_task(task) runner = StandardTaskRunner(job1) - handler = logging.StreamHandler(f) - handler.setFormatter(logging.Formatter(TASK_FORMAT)) - runner.log.addHandler(handler) runner.start() with timeout(seconds=3): @@ -240,14 +235,8 @@ def test_on_kill(self): logging.info(f"Terminating processes {processes} belonging to {runner_pgid} group") runner.terminate() session.close() # explicitly close as `create_session`s commit will blow up otherwise - with open(f.name) as g: - logged = g.read() - os.unlink(f.name) ti.refresh_from_db() - assert re.findall(r'ERROR - Failed to execute job (\S+) for task (\S+)', logged) == [ - (str(ti.job_id), ti.task_id) - ], logged logging.info("Waiting for the on kill killed file to appear") with timeout(seconds=4):
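Reviewer note: the new test_process_os_signal_calls_on_failure_callback case earlier in this patch asserts that the file pointed to by AIRFLOW_CALLBACK_FILE ends up with exactly one line that starts with a component of ti.key.primary and contains the pid of the raw task process (and not the pid of the LocalTaskJob process). A minimal sketch of the kind of `test_on_failure_callback` fixture that behaviour implies is given below; the actual tests/dags file is not part of this hunk, so the task wiring and the exact line format are assumptions.

# Hypothetical sketch of tests/dags/test_on_failure_callback.py as implied by the assertions
# in the signal test (not the actual fixture).
import os
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator


def write_callback_marker(context):
    """Record which task instance and which process executed the failure callback."""
    callback_file = os.environ["AIRFLOW_CALLBACK_FILE"]  # set by the test via monkeypatch
    ti = context["ti"]
    with open(callback_file, "a") as f:
        # The test expects the line to start with a ti.key.primary component and to contain
        # the pid of the process that ran the callback (the `airflow tasks run --raw` process).
        f.write(" ".join(str(part) for part in ti.key.primary) + f" fired callback with pid: {os.getpid()}\n")


with DAG(dag_id="test_on_failure_callback", start_date=datetime(2016, 1, 1), schedule_interval=None):
    # Touch the callback file first so the signalling thread can tell the task has started,
    # then block until SIGTERM/SIGKILL arrives. On SIGTERM the raw task process fails and
    # appends one marker line; on SIGKILL no callback runs and the file stays empty.
    BashOperator(
        task_id="bash_sleep",
        bash_command='touch "$AIRFLOW_CALLBACK_FILE" && sleep 600',
        on_failure_callback=write_callback_marker,
    )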