|
44 | 44 | from airflow.models import clear_task_instances
|
45 | 45 | from airflow.models import XCom
|
46 | 46 | from airflow.models import Connection
|
| 47 | +from airflow.models import SkipMixin |
| 48 | +from airflow.models import KubeResourceVersion, KubeWorkerIdentifier |
47 | 49 | from airflow.jobs import LocalTaskJob
|
48 | 50 | from airflow.operators.dummy_operator import DummyOperator
|
49 | 51 | from airflow.operators.bash_operator import BashOperator
|
50 | 52 | from airflow.operators.python_operator import PythonOperator
|
51 | 53 | from airflow.operators.python_operator import ShortCircuitOperator
|
| 54 | +from airflow.operators.subdag_operator import SubDagOperator |
52 | 55 | from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
|
53 | 56 | from airflow.utils import timezone
|
54 | 57 | from airflow.utils.weight_rule import WeightRule
|
55 | 58 | from airflow.utils.state import State
|
56 | 59 | from airflow.utils.trigger_rule import TriggerRule
|
57 |
| -from mock import patch, ANY |
| 60 | +from mock import patch, Mock, ANY |
58 | 61 | from parameterized import parameterized
|
59 | 62 | from tempfile import mkdtemp, NamedTemporaryFile
|
60 | 63 |
|
@@ -640,6 +643,38 @@ def test_following_previous_schedule_daily_dag_CET_to_CEST(self):
|
640 | 643 | self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
|
641 | 644 | self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
|
642 | 645 |
|
| 646 | + @patch('airflow.models.timezone.utcnow') |
| 647 | + def test_sync_to_db(self, mock_now): |
| 648 | + dag = DAG( |
| 649 | + 'dag', |
| 650 | + start_date=DEFAULT_DATE, |
| 651 | + ) |
| 652 | + with dag: |
| 653 | + DummyOperator(task_id='task', owner='owner1') |
| 654 | + SubDagOperator( |
| 655 | + task_id='subtask', |
| 656 | + owner='owner2', |
| 657 | + subdag=DAG( |
| 658 | + 'dag.subtask', |
| 659 | + start_date=DEFAULT_DATE, |
| 660 | + ) |
| 661 | + ) |
| 662 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 663 | + mock_now.return_value = now |
| 664 | + session = settings.Session() |
| 665 | + dag.sync_to_db(session=session) |
| 666 | + |
| 667 | + orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() |
| 668 | + self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) |
| 669 | + self.assertEqual(orm_dag.last_scheduler_run, now) |
| 670 | + self.assertTrue(orm_dag.is_active) |
| 671 | + |
| 672 | + orm_subdag = session.query(DagModel).filter( |
| 673 | + DagModel.dag_id == 'dag.subtask').one() |
| 674 | + self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'}) |
| 675 | + self.assertEqual(orm_subdag.last_scheduler_run, now) |
| 676 | + self.assertTrue(orm_subdag.is_active) |
| 677 | + |
643 | 678 |
|
644 | 679 | class DagStatTest(unittest.TestCase):
|
645 | 680 | def test_dagstats_crud(self):
|
@@ -690,6 +725,25 @@ def test_dagstats_crud(self):
|
690 | 725 | for stat in res:
|
691 | 726 | self.assertFalse(stat.dirty)
|
692 | 727 |
|
| 728 | + def test_update_exception(self): |
| 729 | + session = Mock() |
| 730 | + (session.query.return_value |
| 731 | + .filter.return_value |
| 732 | + .with_for_update.return_value |
| 733 | + .all.side_effect) = RuntimeError('it broke') |
| 734 | + DagStat.update(session=session) |
| 735 | + session.rollback.assert_called() |
| 736 | + |
| 737 | + def test_set_dirty_exception(self): |
| 738 | + session = Mock() |
| 739 | + session.query.return_value.filter.return_value.all.return_value = [] |
| 740 | + (session.query.return_value |
| 741 | + .filter.return_value |
| 742 | + .with_for_update.return_value |
| 743 | + .all.side_effect) = RuntimeError('it broke') |
| 744 | + DagStat.set_dirty('dag', session) |
| 745 | + session.rollback.assert_called() |
| 746 | + |
693 | 747 |
|
694 | 748 | class DagRunTest(unittest.TestCase):
|
695 | 749 |
|
@@ -2465,6 +2519,35 @@ def success_handler(self, context):
|
2465 | 2519 | ti.refresh_from_db()
|
2466 | 2520 | self.assertEqual(ti.state, State.SUCCESS)
|
2467 | 2521 |
|
| 2522 | + @patch('airflow.models.send_email') |
| 2523 | + def test_email_alert(self, mock_send_email): |
| 2524 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2525 | + ti = TI(task=task, execution_date=datetime.datetime.now()) |
| 2526 | + ti.email_alert(RuntimeError('it broke')) |
| 2527 | + |
| 2528 | + self.assertTrue(mock_send_email.called) |
| 2529 | + (email, title, body), _ = mock_send_email.call_args |
| 2530 | + self.assertEqual(email, 'test@test.test') |
| 2531 | + self.assertIn(repr(ti), title) |
| 2532 | + self.assertIn('it broke', body) |
| 2533 | + |
| 2534 | + def test_set_duration(self): |
| 2535 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2536 | + ti = TI( |
| 2537 | + task=task, |
| 2538 | + execution_date=datetime.datetime.now(), |
| 2539 | + ) |
| 2540 | + ti.start_date = datetime.datetime(2018, 10, 1, 1) |
| 2541 | + ti.end_date = datetime.datetime(2018, 10, 1, 2) |
| 2542 | + ti.set_duration() |
| 2543 | + self.assertEqual(ti.duration, 3600) |
| 2544 | + |
| 2545 | + def test_set_duration_empty_dates(self): |
| 2546 | + task = DummyOperator(task_id='op', email='test@test.test') |
| 2547 | + ti = TI(task=task, execution_date=datetime.datetime.now()) |
| 2548 | + ti.set_duration() |
| 2549 | + self.assertIsNone(ti.duration) |
| 2550 | + |
2468 | 2551 |
|
2469 | 2552 | class ClearTasksTest(unittest.TestCase):
|
2470 | 2553 |
|
@@ -2819,3 +2902,99 @@ def test_connection_from_uri_with_extras(self):
|
2819 | 2902 | self.assertEqual(connection.port, 1234)
|
2820 | 2903 | self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
|
2821 | 2904 | 'extra2': '/path/'})
|
| 2905 | + |
| 2906 | + |
| 2907 | +class TestSkipMixin(unittest.TestCase): |
| 2908 | + |
| 2909 | + @patch('airflow.models.timezone.utcnow') |
| 2910 | + def test_skip(self, mock_now): |
| 2911 | + session = settings.Session() |
| 2912 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 2913 | + mock_now.return_value = now |
| 2914 | + dag = DAG( |
| 2915 | + 'dag', |
| 2916 | + start_date=DEFAULT_DATE, |
| 2917 | + ) |
| 2918 | + with dag: |
| 2919 | + tasks = [DummyOperator(task_id='task')] |
| 2920 | + dag_run = dag.create_dagrun( |
| 2921 | + run_id='manual__' + now.isoformat(), |
| 2922 | + state=State.FAILED, |
| 2923 | + ) |
| 2924 | + SkipMixin().skip( |
| 2925 | + dag_run=dag_run, |
| 2926 | + execution_date=now, |
| 2927 | + tasks=tasks, |
| 2928 | + session=session) |
| 2929 | + |
| 2930 | + session.query(TI).filter( |
| 2931 | + TI.dag_id == 'dag', |
| 2932 | + TI.task_id == 'task', |
| 2933 | + TI.state == State.SKIPPED, |
| 2934 | + TI.start_date == now, |
| 2935 | + TI.end_date == now, |
| 2936 | + ).one() |
| 2937 | + |
| 2938 | + @patch('airflow.models.timezone.utcnow') |
| 2939 | + def test_skip_none_dagrun(self, mock_now): |
| 2940 | + session = settings.Session() |
| 2941 | + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) |
| 2942 | + mock_now.return_value = now |
| 2943 | + dag = DAG( |
| 2944 | + 'dag', |
| 2945 | + start_date=DEFAULT_DATE, |
| 2946 | + ) |
| 2947 | + with dag: |
| 2948 | + tasks = [DummyOperator(task_id='task')] |
| 2949 | + SkipMixin().skip( |
| 2950 | + dag_run=None, |
| 2951 | + execution_date=now, |
| 2952 | + tasks=tasks, |
| 2953 | + session=session) |
| 2954 | + |
| 2955 | + session.query(TI).filter( |
| 2956 | + TI.dag_id == 'dag', |
| 2957 | + TI.task_id == 'task', |
| 2958 | + TI.state == State.SKIPPED, |
| 2959 | + TI.start_date == now, |
| 2960 | + TI.end_date == now, |
| 2961 | + ).one() |
| 2962 | + |
| 2963 | + def test_skip_none_tasks(self): |
| 2964 | + session = Mock() |
| 2965 | + SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session) |
| 2966 | + self.assertFalse(session.query.called) |
| 2967 | + self.assertFalse(session.commit.called) |
| 2968 | + |
| 2969 | + |
| 2970 | +class TestKubeResourceVersion(unittest.TestCase): |
| 2971 | + |
| 2972 | + def test_checkpoint_resource_version(self): |
| 2973 | + session = settings.Session() |
| 2974 | + KubeResourceVersion.checkpoint_resource_version('7', session) |
| 2975 | + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7') |
| 2976 | + |
| 2977 | + def test_reset_resource_version(self): |
| 2978 | + session = settings.Session() |
| 2979 | + version = KubeResourceVersion.reset_resource_version(session) |
| 2980 | + self.assertEqual(version, '0') |
| 2981 | + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0') |
| 2982 | + |
| 2983 | + |
| 2984 | +class TestKubeWorkerIdentifier(unittest.TestCase): |
| 2985 | + |
| 2986 | + @patch('airflow.models.uuid.uuid4') |
| 2987 | + def test_get_or_create_not_exist(self, mock_uuid): |
| 2988 | + session = settings.Session() |
| 2989 | + session.query(KubeWorkerIdentifier).update({ |
| 2990 | + KubeWorkerIdentifier.worker_uuid: '' |
| 2991 | + }) |
| 2992 | + mock_uuid.return_value = 'abcde' |
| 2993 | + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) |
| 2994 | + self.assertEqual(worker_uuid, 'abcde') |
| 2995 | + |
| 2996 | + def test_get_or_create_exist(self): |
| 2997 | + session = settings.Session() |
| 2998 | + KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session) |
| 2999 | + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) |
| 3000 | + self.assertEqual(worker_uuid, 'fghij') |
0 commit comments