Skip to content

Commit 162fc5c

Browse files
jmcarpwayne.morris
authored and
wayne.morris
committed
[AIRFLOW-3129] Improve test coverage of airflow.models. (apache#3982)
1 parent 1db3c1e commit 162fc5c

File tree

2 files changed

+180
-26
lines changed

2 files changed

+180
-26
lines changed

airflow/models.py

-25
Original file line numberDiff line numberDiff line change
@@ -590,21 +590,6 @@ def dagbag_report(self):
590590
table=pprinttable(stats),
591591
)
592592

593-
@provide_session
594-
def deactivate_inactive_dags(self, session=None):
595-
active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
596-
for dag in session.query(
597-
DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
598-
dag.is_active = False
599-
session.merge(dag)
600-
session.commit()
601-
602-
@provide_session
603-
def paused_dags(self, session=None):
604-
dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
605-
DagModel.is_paused.__eq__(True))]
606-
return dag_ids
607-
608593

609594
class User(Base):
610595
__tablename__ = "users"
@@ -4202,16 +4187,6 @@ def add_tasks(self, tasks):
42024187
for task in tasks:
42034188
self.add_task(task)
42044189

4205-
@provide_session
4206-
def db_merge(self, session=None):
4207-
BO = BaseOperator
4208-
tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
4209-
for t in tasks:
4210-
session.delete(t)
4211-
session.commit()
4212-
session.merge(self)
4213-
session.commit()
4214-
42154190
def run(
42164191
self,
42174192
start_date=None,

tests/models.py

+180-1
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,20 @@
4343
from airflow.models import clear_task_instances
4444
from airflow.models import XCom
4545
from airflow.models import Connection
46+
from airflow.models import SkipMixin
47+
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
4648
from airflow.jobs import LocalTaskJob
4749
from airflow.operators.dummy_operator import DummyOperator
4850
from airflow.operators.bash_operator import BashOperator
4951
from airflow.operators.python_operator import PythonOperator
5052
from airflow.operators.python_operator import ShortCircuitOperator
53+
from airflow.operators.subdag_operator import SubDagOperator
5154
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
5255
from airflow.utils import timezone
5356
from airflow.utils.weight_rule import WeightRule
5457
from airflow.utils.state import State
5558
from airflow.utils.trigger_rule import TriggerRule
56-
from mock import patch, ANY
59+
from mock import patch, Mock, ANY
5760
from parameterized import parameterized
5861
from tempfile import mkdtemp, NamedTemporaryFile
5962

@@ -575,6 +578,38 @@ def test_cycle(self):
575578
with self.assertRaises(AirflowDagCycleException):
576579
dag.test_cycle()
577580

581+
@patch('airflow.models.timezone.utcnow')
582+
def test_sync_to_db(self, mock_now):
583+
dag = DAG(
584+
'dag',
585+
start_date=DEFAULT_DATE,
586+
)
587+
with dag:
588+
DummyOperator(task_id='task', owner='owner1')
589+
SubDagOperator(
590+
task_id='subtask',
591+
owner='owner2',
592+
subdag=DAG(
593+
'dag.subtask',
594+
start_date=DEFAULT_DATE,
595+
)
596+
)
597+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
598+
mock_now.return_value = now
599+
session = settings.Session()
600+
dag.sync_to_db(session=session)
601+
602+
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
603+
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
604+
self.assertEqual(orm_dag.last_scheduler_run, now)
605+
self.assertTrue(orm_dag.is_active)
606+
607+
orm_subdag = session.query(DagModel).filter(
608+
DagModel.dag_id == 'dag.subtask').one()
609+
self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
610+
self.assertEqual(orm_subdag.last_scheduler_run, now)
611+
self.assertTrue(orm_subdag.is_active)
612+
578613

579614
class DagStatTest(unittest.TestCase):
580615
def test_dagstats_crud(self):
@@ -625,6 +660,25 @@ def test_dagstats_crud(self):
625660
for stat in res:
626661
self.assertFalse(stat.dirty)
627662

663+
def test_update_exception(self):
664+
session = Mock()
665+
(session.query.return_value
666+
.filter.return_value
667+
.with_for_update.return_value
668+
.all.side_effect) = RuntimeError('it broke')
669+
DagStat.update(session=session)
670+
session.rollback.assert_called()
671+
672+
def test_set_dirty_exception(self):
673+
session = Mock()
674+
session.query.return_value.filter.return_value.all.return_value = []
675+
(session.query.return_value
676+
.filter.return_value
677+
.with_for_update.return_value
678+
.all.side_effect) = RuntimeError('it broke')
679+
DagStat.set_dirty('dag', session)
680+
session.rollback.assert_called()
681+
628682

629683
class DagRunTest(unittest.TestCase):
630684

@@ -2349,6 +2403,35 @@ def test_overwrite_params_with_dag_run_conf_none(self):
23492403

23502404
self.assertEqual(False, params["override"])
23512405

2406+
@patch('airflow.models.send_email')
2407+
def test_email_alert(self, mock_send_email):
2408+
task = DummyOperator(task_id='op', email='test@test.test')
2409+
ti = TI(task=task, execution_date=datetime.datetime.now())
2410+
ti.email_alert(RuntimeError('it broke'))
2411+
2412+
self.assertTrue(mock_send_email.called)
2413+
(email, title, body), _ = mock_send_email.call_args
2414+
self.assertEqual(email, 'test@test.test')
2415+
self.assertIn(repr(ti), title)
2416+
self.assertIn('it broke', body)
2417+
2418+
def test_set_duration(self):
2419+
task = DummyOperator(task_id='op', email='test@test.test')
2420+
ti = TI(
2421+
task=task,
2422+
execution_date=datetime.datetime.now(),
2423+
)
2424+
ti.start_date = datetime.datetime(2018, 10, 1, 1)
2425+
ti.end_date = datetime.datetime(2018, 10, 1, 2)
2426+
ti.set_duration()
2427+
self.assertEqual(ti.duration, 3600)
2428+
2429+
def test_set_duration_empty_dates(self):
2430+
task = DummyOperator(task_id='op', email='test@test.test')
2431+
ti = TI(task=task, execution_date=datetime.datetime.now())
2432+
ti.set_duration()
2433+
self.assertIsNone(ti.duration)
2434+
23522435

23532436
class ClearTasksTest(unittest.TestCase):
23542437

@@ -2705,3 +2788,99 @@ def test_connection_from_uri_with_extras(self):
27052788
self.assertEqual(connection.port, 1234)
27062789
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
27072790
'extra2': '/path/'})
2791+
2792+
2793+
class TestSkipMixin(unittest.TestCase):
2794+
2795+
@patch('airflow.models.timezone.utcnow')
2796+
def test_skip(self, mock_now):
2797+
session = settings.Session()
2798+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
2799+
mock_now.return_value = now
2800+
dag = DAG(
2801+
'dag',
2802+
start_date=DEFAULT_DATE,
2803+
)
2804+
with dag:
2805+
tasks = [DummyOperator(task_id='task')]
2806+
dag_run = dag.create_dagrun(
2807+
run_id='manual__' + now.isoformat(),
2808+
state=State.FAILED,
2809+
)
2810+
SkipMixin().skip(
2811+
dag_run=dag_run,
2812+
execution_date=now,
2813+
tasks=tasks,
2814+
session=session)
2815+
2816+
session.query(TI).filter(
2817+
TI.dag_id == 'dag',
2818+
TI.task_id == 'task',
2819+
TI.state == State.SKIPPED,
2820+
TI.start_date == now,
2821+
TI.end_date == now,
2822+
).one()
2823+
2824+
@patch('airflow.models.timezone.utcnow')
2825+
def test_skip_none_dagrun(self, mock_now):
2826+
session = settings.Session()
2827+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
2828+
mock_now.return_value = now
2829+
dag = DAG(
2830+
'dag',
2831+
start_date=DEFAULT_DATE,
2832+
)
2833+
with dag:
2834+
tasks = [DummyOperator(task_id='task')]
2835+
SkipMixin().skip(
2836+
dag_run=None,
2837+
execution_date=now,
2838+
tasks=tasks,
2839+
session=session)
2840+
2841+
session.query(TI).filter(
2842+
TI.dag_id == 'dag',
2843+
TI.task_id == 'task',
2844+
TI.state == State.SKIPPED,
2845+
TI.start_date == now,
2846+
TI.end_date == now,
2847+
).one()
2848+
2849+
def test_skip_none_tasks(self):
2850+
session = Mock()
2851+
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
2852+
self.assertFalse(session.query.called)
2853+
self.assertFalse(session.commit.called)
2854+
2855+
2856+
class TestKubeResourceVersion(unittest.TestCase):
2857+
2858+
def test_checkpoint_resource_version(self):
2859+
session = settings.Session()
2860+
KubeResourceVersion.checkpoint_resource_version('7', session)
2861+
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7')
2862+
2863+
def test_reset_resource_version(self):
2864+
session = settings.Session()
2865+
version = KubeResourceVersion.reset_resource_version(session)
2866+
self.assertEqual(version, '0')
2867+
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0')
2868+
2869+
2870+
class TestKubeWorkerIdentifier(unittest.TestCase):
2871+
2872+
@patch('airflow.models.uuid.uuid4')
2873+
def test_get_or_create_not_exist(self, mock_uuid):
2874+
session = settings.Session()
2875+
session.query(KubeWorkerIdentifier).update({
2876+
KubeWorkerIdentifier.worker_uuid: ''
2877+
})
2878+
mock_uuid.return_value = 'abcde'
2879+
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
2880+
self.assertEqual(worker_uuid, 'abcde')
2881+
2882+
def test_get_or_create_exist(self):
2883+
session = settings.Session()
2884+
KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session)
2885+
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
2886+
self.assertEqual(worker_uuid, 'fghij')

0 commit comments

Comments
 (0)