Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SqlSensor enhancement: #37437 #43107

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions providers/src/airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -46,10 +47,12 @@ class SqlSensor(BaseSensorOperator):
:param sql: The SQL to run. To pass, it needs to return at least one cell
that contains a non-zero / non-empty string value.
:param parameters: The parameters to render the SQL query with (optional).
:param success: Success criteria for the sensor is a Callable that takes the first_cell's value
as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value
as the only argument and returns a boolean (optional).
:param success: Success criteria for the sensor is a Callable that takes the output
of selector as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the output
of selector as the only argument and returns a boolean (optional).
:param selector: Callable that takes the first row of the query result and transforms it before
it is passed to success or failure (optional). Takes the first cell of that row by default.
:param fail_on_empty: Explicitly fail on no rows returned.
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
Expand All @@ -67,6 +70,7 @@ def __init__(
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
selector: Callable[[tuple[Any]], Any] | None = itemgetter(0),
fail_on_empty: bool = False,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
Expand All @@ -76,6 +80,7 @@ def __init__(
self.parameters = parameters
self.success = success
self.failure = failure
self.selector = selector
self.fail_on_empty = fail_on_empty
self.hook_params = hook_params
super().__init__(**kwargs)
Expand All @@ -102,20 +107,20 @@ def poke(self, context: Context) -> bool:
else:
return False

first_cell = records[0][0]
condition = self.selector(records[0])
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
message = f"Failure criteria met. self.failure({first_cell}) returned True"
if self.failure(condition):
message = f"Failure criteria met. self.failure({condition}) returned True"
raise AirflowException(message)
else:
message = f"self.failure is present, but not callable -> {self.failure}"
raise AirflowException(message)

if self.success is not None:
if callable(self.success):
return self.success(first_cell)
return self.success(condition)
else:
message = f"self.success is present, but not callable -> {self.success}"
raise AirflowException(message)
return bool(first_cell)
return bool(condition)
25 changes: 25 additions & 0 deletions providers/tests/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,31 @@ def test_sql_sensor_postgres_poke_invalid_success(
op.poke({})
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.backend("postgres")
def test_sql_sensor_postgres_with_selector(self):
    """Check that ``selector`` decides which cell of the row is evaluated.

    The query returns the single row ``(0, 1)``. Selecting index 1 hands
    ``1`` to the criteria, which satisfies ``success``; selecting index 0
    hands ``0`` to the criteria, which matches ``failure`` and makes
    ``poke`` raise ``AirflowException``.
    """

    def make_sensor(task_id, column):
        # Both sensors share everything except the task id and the
        # column that the selector extracts from the row.
        return SqlSensor(
            task_id=task_id,
            conn_id="postgres_default",
            sql="SELECT 0, 1",
            dag=self.dag,
            success=lambda value: value in [1],
            failure=lambda value: value in [0],
            selector=lambda row: row[column],
        )

    # Second column (1) meets the success criteria, so a full run completes.
    succeeding_sensor = make_sensor("sql_sensor_check_1", 1)
    succeeding_sensor.run(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        ignore_ti_state=True,
    )

    # First column (0) meets the failure criteria, so poking must raise.
    failing_sensor = make_sensor("sql_sensor_check_2", 0)
    with pytest.raises(AirflowException):
        failing_sensor.poke({})

@pytest.mark.db_test
def test_sql_sensor_hook_params(self):
op = SqlSensor(
Expand Down