|
43 | 43 | from airflow.models import clear_task_instances
|
44 | 44 | from airflow.models import XCom
|
45 | 45 | from airflow.models import Connection
|
| 46 | +from airflow.models import SkipMixin |
| 47 | +from airflow.models import KubeResourceVersion, KubeWorkerIdentifier |
46 | 48 | from airflow.jobs import LocalTaskJob
|
47 | 49 | from airflow.operators.dummy_operator import DummyOperator
|
48 | 50 | from airflow.operators.bash_operator import BashOperator
|
49 | 51 | from airflow.operators.python_operator import PythonOperator
|
50 | 52 | from airflow.operators.python_operator import ShortCircuitOperator
|
| 53 | +from airflow.operators.subdag_operator import SubDagOperator |
51 | 54 | from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
|
52 | 55 | from airflow.utils import timezone
|
53 | 56 | from airflow.utils.weight_rule import WeightRule
|
54 | 57 | from airflow.utils.state import State
|
55 | 58 | from airflow.utils.trigger_rule import TriggerRule
|
56 |
| -from mock import patch, ANY |
| 59 | +from mock import patch, Mock, ANY |
57 | 60 | from parameterized import parameterized
|
58 | 61 | from tempfile import mkdtemp, NamedTemporaryFile
|
59 | 62 |
|
@@ -575,6 +578,38 @@ def test_cycle(self):
|
575 | 578 | with self.assertRaises(AirflowDagCycleException):
|
576 | 579 | dag.test_cycle()
|
577 | 580 |
|
| 581 | + @patch('airflow.models.timezone.utcnow') |
| 582 | + def test_sync_to_db(self, mock_now): |
| 583 | + dag = DAG( |
| 584 | + 'dag', |
| 585 | + start_date=DEFAULT_DATE, |
| 586 | + ) |
| 587 | + with dag: |
| 588 | + DummyOperator(task_id='task', owner='owner1') |
| 589 | + SubDagOperator( |
| 590 | + task_id='subtask', |
| 591 | + owner='owner2', |
| 592 | + subdag=DAG( |
| 593 | + 'dag.subtask', |
| 594 | + start_date=DEFAULT_DATE, |
| 595 | + ) |
| 596 | + ) |
| 597 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 598 | + mock_now.return_value = now |
| 599 | + session = settings.Session() |
| 600 | + dag.sync_to_db(session=session) |
| 601 | + |
| 602 | + orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() |
| 603 | + self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) |
| 604 | + self.assertEqual(orm_dag.last_scheduler_run, now) |
| 605 | + self.assertTrue(orm_dag.is_active) |
| 606 | + |
| 607 | + orm_subdag = session.query(DagModel).filter( |
| 608 | + DagModel.dag_id == 'dag.subtask').one() |
| 609 | + self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'}) |
| 610 | + self.assertEqual(orm_subdag.last_scheduler_run, now) |
| 611 | + self.assertTrue(orm_subdag.is_active) |
| 612 | + |
578 | 613 |
|
579 | 614 | class DagStatTest(unittest.TestCase):
|
580 | 615 | def test_dagstats_crud(self):
|
@@ -625,6 +660,25 @@ def test_dagstats_crud(self):
|
625 | 660 | for stat in res:
|
626 | 661 | self.assertFalse(stat.dirty)
|
627 | 662 |
|
| 663 | + def test_update_exception(self): |
| 664 | + session = Mock() |
| 665 | + (session.query.return_value |
| 666 | + .filter.return_value |
| 667 | + .with_for_update.return_value |
| 668 | + .all.side_effect) = RuntimeError('it broke') |
| 669 | + DagStat.update(session=session) |
| 670 | + session.rollback.assert_called() |
| 671 | + |
| 672 | + def test_set_dirty_exception(self): |
| 673 | + session = Mock() |
| 674 | + session.query.return_value.filter.return_value.all.return_value = [] |
| 675 | + (session.query.return_value |
| 676 | + .filter.return_value |
| 677 | + .with_for_update.return_value |
| 678 | + .all.side_effect) = RuntimeError('it broke') |
| 679 | + DagStat.set_dirty('dag', session) |
| 680 | + session.rollback.assert_called() |
| 681 | + |
628 | 682 |
|
629 | 683 | class DagRunTest(unittest.TestCase):
|
630 | 684 |
|
@@ -2349,6 +2403,35 @@ def test_overwrite_params_with_dag_run_conf_none(self):
|
2349 | 2403 |
|
2350 | 2404 | self.assertEqual(False, params["override"])
|
2351 | 2405 |
|
| 2406 | + @patch('airflow.models.send_email') |
| 2407 | + def test_email_alert(self, mock_send_email): |
| 2408 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2409 | + ti = TI(task=task, execution_date=datetime.datetime.now()) |
| 2410 | + ti.email_alert(RuntimeError('it broke')) |
| 2411 | + |
| 2412 | + self.assertTrue(mock_send_email.called) |
| 2413 | + (email, title, body), _ = mock_send_email.call_args |
| 2414 | + self.assertEqual(email, 'test@test.test') |
| 2415 | + self.assertIn(repr(ti), title) |
| 2416 | + self.assertIn('it broke', body) |
| 2417 | + |
| 2418 | + def test_set_duration(self): |
| 2419 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2420 | + ti = TI( |
| 2421 | + task=task, |
| 2422 | + execution_date=datetime.datetime.now(), |
| 2423 | + ) |
| 2424 | + ti.start_date = datetime.datetime(2018, 10, 1, 1) |
| 2425 | + ti.end_date = datetime.datetime(2018, 10, 1, 2) |
| 2426 | + ti.set_duration() |
| 2427 | + self.assertEqual(ti.duration, 3600) |
| 2428 | + |
| 2429 | + def test_set_duration_empty_dates(self): |
| 2430 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2431 | + ti = TI(task=task, execution_date=datetime.datetime.now()) |
| 2432 | + ti.set_duration() |
| 2433 | + self.assertIsNone(ti.duration) |
| 2434 | + |
2352 | 2435 |
|
2353 | 2436 | class ClearTasksTest(unittest.TestCase):
|
2354 | 2437 |
|
@@ -2705,3 +2788,99 @@ def test_connection_from_uri_with_extras(self):
|
2705 | 2788 | self.assertEqual(connection.port, 1234)
|
2706 | 2789 | self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
|
2707 | 2790 | 'extra2': '/path/'})
|
| 2791 | + |
| 2792 | + |
| 2793 | +class TestSkipMixin(unittest.TestCase): |
| 2794 | + |
| 2795 | + @patch('airflow.models.timezone.utcnow') |
| 2796 | + def test_skip(self, mock_now): |
| 2797 | + session = settings.Session() |
| 2798 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 2799 | + mock_now.return_value = now |
| 2800 | + dag = DAG( |
| 2801 | + 'dag', |
| 2802 | + start_date=DEFAULT_DATE, |
| 2803 | + ) |
| 2804 | + with dag: |
| 2805 | + tasks = [DummyOperator(task_id='task')] |
| 2806 | + dag_run = dag.create_dagrun( |
| 2807 | + run_id='manual__' + now.isoformat(), |
| 2808 | + state=State.FAILED, |
| 2809 | + ) |
| 2810 | + SkipMixin().skip( |
| 2811 | + dag_run=dag_run, |
| 2812 | + execution_date=now, |
| 2813 | + tasks=tasks, |
| 2814 | + session=session) |
| 2815 | + |
| 2816 | + session.query(TI).filter( |
| 2817 | + TI.dag_id == 'dag', |
| 2818 | + TI.task_id == 'task', |
| 2819 | + TI.state == State.SKIPPED, |
| 2820 | + TI.start_date == now, |
| 2821 | + TI.end_date == now, |
| 2822 | + ).one() |
| 2823 | + |
| 2824 | + @patch('airflow.models.timezone.utcnow') |
| 2825 | + def test_skip_none_dagrun(self, mock_now): |
| 2826 | + session = settings.Session() |
| 2827 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 2828 | + mock_now.return_value = now |
| 2829 | + dag = DAG( |
| 2830 | + 'dag', |
| 2831 | + start_date=DEFAULT_DATE, |
| 2832 | + ) |
| 2833 | + with dag: |
| 2834 | + tasks = [DummyOperator(task_id='task')] |
| 2835 | + SkipMixin().skip( |
| 2836 | + dag_run=None, |
| 2837 | + execution_date=now, |
| 2838 | + tasks=tasks, |
| 2839 | + session=session) |
| 2840 | + |
| 2841 | + session.query(TI).filter( |
| 2842 | + TI.dag_id == 'dag', |
| 2843 | + TI.task_id == 'task', |
| 2844 | + TI.state == State.SKIPPED, |
| 2845 | + TI.start_date == now, |
| 2846 | + TI.end_date == now, |
| 2847 | + ).one() |
| 2848 | + |
| 2849 | + def test_skip_none_tasks(self): |
| 2850 | + session = Mock() |
| 2851 | + SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session) |
| 2852 | + self.assertFalse(session.query.called) |
| 2853 | + self.assertFalse(session.commit.called) |
| 2854 | + |
| 2855 | + |
| 2856 | +class TestKubeResourceVersion(unittest.TestCase): |
| 2857 | + |
| 2858 | + def test_checkpoint_resource_version(self): |
| 2859 | + session = settings.Session() |
| 2860 | + KubeResourceVersion.checkpoint_resource_version('7', session) |
| 2861 | + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7') |
| 2862 | + |
| 2863 | + def test_reset_resource_version(self): |
| 2864 | + session = settings.Session() |
| 2865 | + version = KubeResourceVersion.reset_resource_version(session) |
| 2866 | + self.assertEqual(version, '0') |
| 2867 | + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0') |
| 2868 | + |
| 2869 | + |
| 2870 | +class TestKubeWorkerIdentifier(unittest.TestCase): |
| 2871 | + |
| 2872 | + @patch('airflow.models.uuid.uuid4') |
| 2873 | + def test_get_or_create_not_exist(self, mock_uuid): |
| 2874 | + session = settings.Session() |
| 2875 | + session.query(KubeWorkerIdentifier).update({ |
| 2876 | + KubeWorkerIdentifier.worker_uuid: '' |
| 2877 | + }) |
| 2878 | + mock_uuid.return_value = 'abcde' |
| 2879 | + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) |
| 2880 | + self.assertEqual(worker_uuid, 'abcde') |
| 2881 | + |
| 2882 | + def test_get_or_create_exist(self): |
| 2883 | + session = settings.Session() |
| 2884 | + KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session) |
| 2885 | + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) |
| 2886 | + self.assertEqual(worker_uuid, 'fghij') |
0 commit comments