Fix task retries when tasks receive SIGKILL and have retries; properly handle SIGTERM too #16301

Merged · 1 commit · Jul 28, 2021
24 changes: 12 additions & 12 deletions airflow/jobs/local_task_job.py
@@ -78,12 +78,9 @@ def _execute(self):
def signal_handler(signum, frame):
"""Setting kill signal handler"""
self.log.error("Received SIGTERM. Terminating subprocesses")
self.on_kill()
self.task_instance.refresh_from_db()
if self.task_instance.state not in State.finished:
self.task_instance.set_state(State.FAILED)
self.task_instance._run_finished_callback(error="task received sigterm")
raise AirflowException("LocalTaskJob received SIGTERM signal")
self.task_runner.terminate()
self.handle_task_exit(128 + signum)
return

signal.signal(signal.SIGTERM, signal_handler)

@@ -148,16 +145,19 @@ def signal_handler(signum, frame):
self.on_kill()

def handle_task_exit(self, return_code: int) -> None:
"""Handle case where self.task_runner exits by itself"""
"""Handle case where self.task_runner exits by itself or is externally killed"""
# Without setting this, heartbeat may get us
self.terminating = True
self.log.info("Task exited with return code %s", return_code)
self.task_instance.refresh_from_db()
# task exited by itself, so we need to check for error file

if self.task_instance.state == State.RUNNING:
# This is for a case where the task received a SIGKILL
# while running or the task runner received a sigterm
self.task_instance.handle_failure(error=None)
# We need to check for error file
# in case it failed due to runtime exception/error
error = None
if self.task_instance.state == State.RUNNING:
# This is for a case where the task received a sigkill
# while running
self.task_instance.set_state(State.FAILED)
if self.task_instance.state != State.SUCCESS:
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error)
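The substance of the fix is in the hunks above: the SIGTERM handler no longer force-fails the task instance and raises, it terminates the task runner and routes into `handle_task_exit` with the conventional `128 + signum` exit code, and `handle_task_exit` now calls `TaskInstance.handle_failure` instead of `set_state(State.FAILED)`, so remaining retries are honoured. A minimal sketch of that pattern follows; the `Fake*` classes and method names are simplified stand-ins for illustration, not the actual Airflow objects.

```python
import signal

RUNNING, FAILED, UP_FOR_RETRY = "running", "failed", "up_for_retry"


class FakeTaskInstance:
    """Simplified stand-in for airflow.models.TaskInstance."""

    def __init__(self, retries_left: int):
        self.state = RUNNING
        self.retries_left = retries_left

    def handle_failure(self, error=None):
        # Unlike set_state(FAILED), this respects remaining retries.
        self.state = UP_FOR_RETRY if self.retries_left > 0 else FAILED


class FakeLocalTaskJob:
    """Simplified stand-in showing how the signal and exit paths converge."""

    def __init__(self, task_instance, task_runner):
        self.task_instance = task_instance
        self.task_runner = task_runner
        self.terminating = False

    def handle_task_exit(self, return_code: int) -> None:
        """Called when the runner exits on its own and on SIGTERM."""
        # Without this flag the heartbeat could race with the shutdown.
        self.terminating = True
        if self.task_instance.state == RUNNING:
            # Task got SIGKILL (or the runner got SIGTERM) while running:
            # let handle_failure choose between FAILED and UP_FOR_RETRY.
            self.task_instance.handle_failure(error=None)

    def install_sigterm_handler(self) -> None:
        def signal_handler(signum, frame):
            self.task_runner.terminate()
            # 128 + signum is the usual shell convention for processes
            # killed by a signal (e.g. 143 for SIGTERM, 137 for SIGKILL).
            self.handle_task_exit(128 + signum)

        signal.signal(signal.SIGTERM, signal_handler)
```

With both the SIGKILL and SIGTERM paths funnelled through the same exit handler, a killed task that still has retries left ends up in UP_FOR_RETRY rather than FAILED, which is what the new tests below assert.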
169 changes: 146 additions & 23 deletions tests/jobs/test_local_task_job.py
@@ -21,6 +21,7 @@
import signal
import time
import uuid
from datetime import timedelta
from multiprocessing import Lock, Value
from unittest import mock
from unittest.mock import patch
@@ -228,7 +229,6 @@ def test_heartbeat_failed_fast(self):
delta = (time2 - time1).total_seconds()
assert abs(delta - job.heartrate) < 0.5

@pytest.mark.quarantined
def test_mark_success_no_kill(self):
"""
Test that ensures that mark_success in the UI doesn't cause
Expand Down Expand Up @@ -256,7 +256,6 @@ def test_mark_success_no_kill(self):
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
process = multiprocessing.Process(target=job1.run)
process.start()
ti.refresh_from_db()
for _ in range(0, 50):
if ti.state == State.RUNNING:
break
@@ -466,7 +465,6 @@ def dummy_return_code(*args, **kwargs):
assert ti.state == State.FAILED # task exits with failure state
assert failure_callback_called.value == 1

@pytest.mark.quarantined
def test_mark_success_on_success_callback(self, dag_maker):
"""
Test that ensures that where a task is marked success in the UI
@@ -523,16 +521,9 @@ def task_function(ti):
assert task_terminated_externally.value == 1
assert not process.is_alive()

@parameterized.expand(
[
(signal.SIGTERM,),
(signal.SIGKILL,),
]
)
@pytest.mark.quarantined
def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
def test_task_sigkill_calls_on_failure_callback(self, dag_maker):
"""
Test that ensures that when a task is killed with sigterm or sigkill
Test that ensures that when a task is killed with sigkill
on_failure_callback gets executed
"""
# use shared memory value so we can properly track value change even if
@@ -544,10 +535,50 @@ def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
def failure_callback(context):
with shared_mem_lock:
failure_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure'
assert context['dag_run'].dag_id == 'test_send_sigkill'

def task_function(ti):
os.kill(os.getpid(), signal.SIGKILL)
# This should not happen -- the state change should be noticed and the task should get killed
with shared_mem_lock:
task_terminated_externally.value = 0

with dag_maker(dag_id='test_send_sigkill'):
task = PythonOperator(
task_id='test_on_failure',
python_callable=task_function,
on_failure_callback=failure_callback,
)

ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
settings.engine.dispose()
process = multiprocessing.Process(target=job1.run)
process.start()
time.sleep(0.3)
process.join(timeout=10)
assert failure_callback_called.value == 1
assert task_terminated_externally.value == 1
assert not process.is_alive()

def test_process_sigterm_calls_on_failure_callback(self, dag_maker):
"""
Test that ensures that when a task runner is killed with sigterm
on_failure_callback gets executed
"""
# use shared memory value so we can properly track value change even if
# it's been updated across processes.
failure_callback_called = Value('i', 0)
task_terminated_externally = Value('i', 1)
shared_mem_lock = Lock()

def failure_callback(context):
with shared_mem_lock:
failure_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure'

def task_function(ti):
time.sleep(60)
# This should not happen -- the state change should be noticed and the task should get killed
with shared_mem_lock:
@@ -562,20 +593,16 @@ def task_function(ti):
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
job1.task_runner = StandardTaskRunner(job1)

settings.engine.dispose()
process = multiprocessing.Process(target=job1.run)
process.start()

for _ in range(0, 20):
for _ in range(0, 25):
ti.refresh_from_db()
if ti.state == State.RUNNING and ti.pid is not None:
if ti.state == State.RUNNING:
break
time.sleep(0.2)
assert ti.pid is not None
assert ti.state == State.RUNNING
os.kill(ti.pid, signal_type)
os.kill(process.pid, signal.SIGTERM)
ti.refresh_from_db()
process.join(timeout=10)
assert failure_callback_called.value == 1
assert task_terminated_externally.value == 1
@@ -683,7 +710,103 @@ def test_fast_follow(
if scheduler_job.processor_agent:
scheduler_job.processor_agent.end()

def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker):
def test_task_sigkill_works_with_retries(self, dag_maker):
"""
Test that ensures that tasks are retried when they receive sigkill
"""
# use shared memory value so we can properly track value change even if
# it's been updated across processes.
retry_callback_called = Value('i', 0)
task_terminated_externally = Value('i', 1)
shared_mem_lock = Lock()

def retry_callback(context):
with shared_mem_lock:
retry_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure_2'

def task_function(ti):
os.kill(os.getpid(), signal.SIGKILL)
# This should not happen -- the state change should be noticed and the task should get killed
with shared_mem_lock:
task_terminated_externally.value = 0

with dag_maker(
dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
):
task = PythonOperator(
task_id='test_on_failure',
python_callable=task_function,
retries=1,
retry_delay=timedelta(seconds=2),
on_retry_callback=retry_callback,
)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
job1.task_runner = StandardTaskRunner(job1)
job1.task_runner.start()
settings.engine.dispose()
process = multiprocessing.Process(target=job1.run)
process.start()
time.sleep(0.4)
process.join()
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY
assert retry_callback_called.value == 1
assert task_terminated_externally.value == 1

def test_process_sigterm_works_with_retries(self, dag_maker):
"""
Test that ensures that the task runner sets tasks up for retry when it
receives sigterm
"""
# use shared memory value so we can properly track value change even if
# it's been updated across processes.
retry_callback_called = Value('i', 0)
task_terminated_externally = Value('i', 1)
shared_mem_lock = Lock()

def retry_callback(context):
with shared_mem_lock:
retry_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure_2'

def task_function(ti):
time.sleep(60)
# This should not happen -- the state change should be noticed and the task should get killed
with shared_mem_lock:
task_terminated_externally.value = 0

with dag_maker(dag_id='test_mark_failure_2'):
task = PythonOperator(
task_id='test_on_failure',
python_callable=task_function,
retries=1,
retry_delay=timedelta(seconds=2),
on_retry_callback=retry_callback,
)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
job1.task_runner = StandardTaskRunner(job1)
job1.task_runner.start()
settings.engine.dispose()
process = multiprocessing.Process(target=job1.run)
process.start()
for _ in range(0, 25):
ti.refresh_from_db()
if ti.state == State.RUNNING and ti.pid is not None:
break
time.sleep(0.2)
os.kill(process.pid, signal.SIGTERM)
process.join()
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY
assert retry_callback_called.value == 1
assert task_terminated_externally.value == 1

def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
"""Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
@@ -746,5 +869,5 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes)
mock_get_task_runner.return_value.return_code.side_effects = return_codes

job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
with assert_queries_count(16):
with assert_queries_count(18):
job.run()
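Several of the new tests above share one cross-process bookkeeping trick, flagged by the recurring comment about shared memory values: a multiprocessing.Value guarded by a Lock lets a callback running inside the child LocalTaskJob process record that it fired, so the parent test process can assert on it after join(). A generic sketch of that pattern (standard library only, not Airflow-specific):

```python
import multiprocessing
from multiprocessing import Lock, Value


def child(callback_called, shared_mem_lock):
    # Runs in the child process; increment the shared counter under the
    # lock so concurrent writers cannot race.
    with shared_mem_lock:
        callback_called.value += 1


if __name__ == "__main__":
    callback_called = Value('i', 0)  # shared int visible to both processes
    shared_mem_lock = Lock()

    process = multiprocessing.Process(target=child, args=(callback_called, shared_mem_lock))
    process.start()
    process.join(timeout=10)

    # The parent observes the change made in the child.
    assert callback_called.value == 1
    assert not process.is_alive()
```

Passing the Value and Lock to the child explicitly keeps the sketch correct under both the fork and spawn start methods.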
32 changes: 32 additions & 0 deletions tests/models/test_taskinstance.py
@@ -18,6 +18,7 @@

import datetime
import os
import signal
import time
import unittest
import urllib
@@ -530,6 +531,37 @@ def raise_skip_exception():
ti.run()
assert State.SKIPPED == ti.state

def test_task_sigterm_works_with_retries(self):
"""
Test that ensures that tasks are retried when they receive sigterm
"""
dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

def task_function(ti):
# pylint: disable=unused-argument
os.kill(ti.pid, signal.SIGTERM)

task = PythonOperator(
task_id='test_on_failure',
python_callable=task_function,
retries=1,
retry_delay=datetime.timedelta(seconds=2),
dag=dag,
)

dag.create_dagrun(
run_id="test",
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
)
ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
with self.assertRaises(AirflowException):
ti.run()
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY

def test_retry_delay(self):
"""
Test that retry delays are respected
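The new taskinstance test above covers the raw-task side of the same story: the task sends itself SIGTERM, ti.run() raises AirflowException, and because the failure flows through the normal handle_failure path the instance lands in UP_FOR_RETRY. A rough sketch of that mechanism, assuming the raw task installs a SIGTERM handler that raises (the function and parameter names below are illustrative stand-ins, not Airflow's internals):

```python
import signal


class AirflowException(Exception):
    """Stand-in for airflow.exceptions.AirflowException."""


def run_raw_task(execute_callable, handle_failure):
    """Sketch: run a task body with SIGTERM converted into an exception."""

    def signal_handler(signum, frame):
        raise AirflowException("Task received SIGTERM signal")

    signal.signal(signal.SIGTERM, signal_handler)
    try:
        execute_callable()
    except AirflowException as err:
        # handle_failure() is where retries are counted: with retries
        # remaining the task moves to UP_FOR_RETRY, otherwise FAILED.
        handle_failure(err)
        raise
```

That flow matches the test's assertRaises(AirflowException) followed by the UP_FOR_RETRY assertion.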