Skip to content

Commit f56462d

Browse files
jmcarpChris Fei
authored and
Chris Fei
committed
Improve test coverage of airflow.models. (apache#3982)
1 parent 69585f8 commit f56462d

File tree

2 files changed

+180
-26
lines changed

2 files changed

+180
-26
lines changed

airflow/models.py

-25
Original file line numberDiff line numberDiff line change
@@ -625,21 +625,6 @@ def dagbag_report(self):
625625
table=pprinttable(stats),
626626
)
627627

628-
@provide_session
629-
def deactivate_inactive_dags(self, session=None):
630-
active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
631-
for dag in session.query(
632-
DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
633-
dag.is_active = False
634-
session.merge(dag)
635-
session.commit()
636-
637-
@provide_session
638-
def paused_dags(self, session=None):
639-
dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
640-
DagModel.is_paused.__eq__(True))]
641-
return dag_ids
642-
643628

644629
class User(Base):
645630
__tablename__ = "users"
@@ -4202,16 +4187,6 @@ def add_tasks(self, tasks):
42024187
for task in tasks:
42034188
self.add_task(task)
42044189

4205-
@provide_session
4206-
def db_merge(self, session=None):
4207-
BO = BaseOperator
4208-
tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
4209-
for t in tasks:
4210-
session.delete(t)
4211-
session.commit()
4212-
session.merge(self)
4213-
session.commit()
4214-
42154190
def run(
42164191
self,
42174192
start_date=None,

tests/models.py

+180-1
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,20 @@
4444
from airflow.models import clear_task_instances
4545
from airflow.models import XCom
4646
from airflow.models import Connection
47+
from airflow.models import SkipMixin
48+
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
4749
from airflow.jobs import LocalTaskJob
4850
from airflow.operators.dummy_operator import DummyOperator
4951
from airflow.operators.bash_operator import BashOperator
5052
from airflow.operators.python_operator import PythonOperator
5153
from airflow.operators.python_operator import ShortCircuitOperator
54+
from airflow.operators.subdag_operator import SubDagOperator
5255
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
5356
from airflow.utils import timezone
5457
from airflow.utils.weight_rule import WeightRule
5558
from airflow.utils.state import State
5659
from airflow.utils.trigger_rule import TriggerRule
57-
from mock import patch, ANY
60+
from mock import patch, Mock, ANY
5861
from parameterized import parameterized
5962
from tempfile import mkdtemp, NamedTemporaryFile
6063

@@ -640,6 +643,38 @@ def test_following_previous_schedule_daily_dag_CET_to_CEST(self):
640643
self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
641644
self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
642645

646+
@patch('airflow.models.timezone.utcnow')
647+
def test_sync_to_db(self, mock_now):
648+
dag = DAG(
649+
'dag',
650+
start_date=DEFAULT_DATE,
651+
)
652+
with dag:
653+
DummyOperator(task_id='task', owner='owner1')
654+
SubDagOperator(
655+
task_id='subtask',
656+
owner='owner2',
657+
subdag=DAG(
658+
'dag.subtask',
659+
start_date=DEFAULT_DATE,
660+
)
661+
)
662+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
663+
mock_now.return_value = now
664+
session = settings.Session()
665+
dag.sync_to_db(session=session)
666+
667+
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
668+
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
669+
self.assertEqual(orm_dag.last_scheduler_run, now)
670+
self.assertTrue(orm_dag.is_active)
671+
672+
orm_subdag = session.query(DagModel).filter(
673+
DagModel.dag_id == 'dag.subtask').one()
674+
self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
675+
self.assertEqual(orm_subdag.last_scheduler_run, now)
676+
self.assertTrue(orm_subdag.is_active)
677+
643678

644679
class DagStatTest(unittest.TestCase):
645680
def test_dagstats_crud(self):
@@ -690,6 +725,25 @@ def test_dagstats_crud(self):
690725
for stat in res:
691726
self.assertFalse(stat.dirty)
692727

728+
def test_update_exception(self):
729+
session = Mock()
730+
(session.query.return_value
731+
.filter.return_value
732+
.with_for_update.return_value
733+
.all.side_effect) = RuntimeError('it broke')
734+
DagStat.update(session=session)
735+
session.rollback.assert_called()
736+
737+
def test_set_dirty_exception(self):
738+
session = Mock()
739+
session.query.return_value.filter.return_value.all.return_value = []
740+
(session.query.return_value
741+
.filter.return_value
742+
.with_for_update.return_value
743+
.all.side_effect) = RuntimeError('it broke')
744+
DagStat.set_dirty('dag', session)
745+
session.rollback.assert_called()
746+
693747

694748
class DagRunTest(unittest.TestCase):
695749

@@ -2465,6 +2519,35 @@ def success_handler(self, context):
24652519
ti.refresh_from_db()
24662520
self.assertEqual(ti.state, State.SUCCESS)
24672521

2522+
@patch('airflow.models.send_email')
2523+
def test_email_alert(self, mock_send_email):
2524+
task = DummyOperator(task_id='op', email='test@test.test')
2525+
ti = TI(task=task, execution_date=datetime.datetime.now())
2526+
ti.email_alert(RuntimeError('it broke'))
2527+
2528+
self.assertTrue(mock_send_email.called)
2529+
(email, title, body), _ = mock_send_email.call_args
2530+
self.assertEqual(email, 'test@test.test')
2531+
self.assertIn(repr(ti), title)
2532+
self.assertIn('it broke', body)
2533+
2534+
def test_set_duration(self):
2535+
task = DummyOperator(task_id='op', email='test@test.test')
2536+
ti = TI(
2537+
task=task,
2538+
execution_date=datetime.datetime.now(),
2539+
)
2540+
ti.start_date = datetime.datetime(2018, 10, 1, 1)
2541+
ti.end_date = datetime.datetime(2018, 10, 1, 2)
2542+
ti.set_duration()
2543+
self.assertEqual(ti.duration, 3600)
2544+
2545+
def test_set_duration_empty_dates(self):
2546+
task = DummyOperator(task_id='op', email='test@test.test')
2547+
ti = TI(task=task, execution_date=datetime.datetime.now())
2548+
ti.set_duration()
2549+
self.assertIsNone(ti.duration)
2550+
24682551

24692552
class ClearTasksTest(unittest.TestCase):
24702553

@@ -2819,3 +2902,99 @@ def test_connection_from_uri_with_extras(self):
28192902
self.assertEqual(connection.port, 1234)
28202903
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
28212904
'extra2': '/path/'})
2905+
2906+
2907+
class TestSkipMixin(unittest.TestCase):
2908+
2909+
@patch('airflow.models.timezone.utcnow')
2910+
def test_skip(self, mock_now):
2911+
session = settings.Session()
2912+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
2913+
mock_now.return_value = now
2914+
dag = DAG(
2915+
'dag',
2916+
start_date=DEFAULT_DATE,
2917+
)
2918+
with dag:
2919+
tasks = [DummyOperator(task_id='task')]
2920+
dag_run = dag.create_dagrun(
2921+
run_id='manual__' + now.isoformat(),
2922+
state=State.FAILED,
2923+
)
2924+
SkipMixin().skip(
2925+
dag_run=dag_run,
2926+
execution_date=now,
2927+
tasks=tasks,
2928+
session=session)
2929+
2930+
session.query(TI).filter(
2931+
TI.dag_id == 'dag',
2932+
TI.task_id == 'task',
2933+
TI.state == State.SKIPPED,
2934+
TI.start_date == now,
2935+
TI.end_date == now,
2936+
).one()
2937+
2938+
@patch('airflow.models.timezone.utcnow')
2939+
def test_skip_none_dagrun(self, mock_now):
2940+
session = settings.Session()
2941+
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
2942+
mock_now.return_value = now
2943+
dag = DAG(
2944+
'dag',
2945+
start_date=DEFAULT_DATE,
2946+
)
2947+
with dag:
2948+
tasks = [DummyOperator(task_id='task')]
2949+
SkipMixin().skip(
2950+
dag_run=None,
2951+
execution_date=now,
2952+
tasks=tasks,
2953+
session=session)
2954+
2955+
session.query(TI).filter(
2956+
TI.dag_id == 'dag',
2957+
TI.task_id == 'task',
2958+
TI.state == State.SKIPPED,
2959+
TI.start_date == now,
2960+
TI.end_date == now,
2961+
).one()
2962+
2963+
def test_skip_none_tasks(self):
2964+
session = Mock()
2965+
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
2966+
self.assertFalse(session.query.called)
2967+
self.assertFalse(session.commit.called)
2968+
2969+
2970+
class TestKubeResourceVersion(unittest.TestCase):
2971+
2972+
def test_checkpoint_resource_version(self):
2973+
session = settings.Session()
2974+
KubeResourceVersion.checkpoint_resource_version('7', session)
2975+
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7')
2976+
2977+
def test_reset_resource_version(self):
2978+
session = settings.Session()
2979+
version = KubeResourceVersion.reset_resource_version(session)
2980+
self.assertEqual(version, '0')
2981+
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0')
2982+
2983+
2984+
class TestKubeWorkerIdentifier(unittest.TestCase):
2985+
2986+
@patch('airflow.models.uuid.uuid4')
2987+
def test_get_or_create_not_exist(self, mock_uuid):
2988+
session = settings.Session()
2989+
session.query(KubeWorkerIdentifier).update({
2990+
KubeWorkerIdentifier.worker_uuid: ''
2991+
})
2992+
mock_uuid.return_value = 'abcde'
2993+
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
2994+
self.assertEqual(worker_uuid, 'abcde')
2995+
2996+
def test_get_or_create_exist(self):
2997+
session = settings.Session()
2998+
KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session)
2999+
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
3000+
self.assertEqual(worker_uuid, 'fghij')

0 commit comments

Comments
 (0)