|
17 | 17 | # specific language governing permissions and limitations
|
18 | 18 | # under the License.
|
19 | 19 |
|
20 |
| -from airflow.models import TaskInstance |
| 20 | +from airflow.models import TaskInstance, DagRun |
21 | 21 | from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
22 | 22 | from airflow.utils.db import provide_session
|
23 | 23 | from airflow.utils.decorators import apply_defaults
|
|
26 | 26 |
|
27 | 27 | class ExternalTaskSensor(BaseSensorOperator):
|
28 | 28 | """
|
29 |
| - Waits for a task to complete in a different DAG |
| 29 | + Waits for a different DAG or a task in in a different DAG to complete |
30 | 30 |
|
31 | 31 | :param external_dag_id: The dag_id that contains the task you want to
|
32 | 32 | wait for
|
33 | 33 | :type external_dag_id: str
|
34 | 34 | :param external_task_id: The task_id that contains the task you want to
|
35 |
| - wait for |
| 35 | + wait for. If ``None`` the sensor waits for the DAG |
36 | 36 | :type external_task_id: str
|
37 | 37 | :param allowed_states: list of allowed states, default is ``['success']``
|
38 | 38 | :type allowed_states: list
|
39 | 39 | :param execution_delta: time difference with the previous execution to
|
40 |
| - look at, the default is the same execution_date as the current task. |
| 40 | + look at, the default is the same execution_date as the current task or DAG. |
41 | 41 | For yesterday, use [positive!] datetime.timedelta(days=1). Either
|
42 | 42 | execution_delta or execution_date_fn can be passed to
|
43 | 43 | ExternalTaskSensor, but not both.
|
@@ -102,13 +102,23 @@ def poke(self, context, session=None):
|
102 | 102 | '{self.external_dag_id}.'
|
103 | 103 | '{self.external_task_id} on '
|
104 | 104 | '{} ... '.format(serialized_dttm_filter, **locals()))
|
105 |
| - TI = TaskInstance |
106 | 105 |
|
107 |
| - count = session.query(TI).filter( |
108 |
| - TI.dag_id == self.external_dag_id, |
109 |
| - TI.task_id == self.external_task_id, |
110 |
| - TI.state.in_(self.allowed_states), |
111 |
| - TI.execution_date.in_(dttm_filter), |
112 |
| - ).count() |
| 106 | + if self.external_task_id: |
| 107 | + TI = TaskInstance |
| 108 | + |
| 109 | + count = session.query(TI).filter( |
| 110 | + TI.dag_id == self.external_dag_id, |
| 111 | + TI.task_id == self.external_task_id, |
| 112 | + TI.state.in_(self.allowed_states), |
| 113 | + TI.execution_date.in_(dttm_filter), |
| 114 | + ).count() |
| 115 | + else: |
| 116 | + DR = DagRun |
| 117 | + count = session.query(DR).filter( |
| 118 | + DR.dag_id == self.external_dag_id, |
| 119 | + DR.state.in_(self.allowed_states), |
| 120 | + DR.execution_date.in_(dttm_filter), |
| 121 | + ).count() |
| 122 | + |
113 | 123 | session.commit()
|
114 | 124 | return count == len(dttm_filter)
|
0 commit comments