diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index c0a3ed5f9a51c..cd48303a825d6 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -274,7 +274,6 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): """ if engine_kwargs is None: engine_kwargs = {} - engine_kwargs["creator"] = self.get_conn try: url = self.sqlalchemy_url diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py index 27a438ae414cf..356bd5d450606 100644 --- a/airflow/providers/jdbc/hooks/jdbc.py +++ b/airflow/providers/jdbc/hooks/jdbc.py @@ -163,6 +163,19 @@ def sqlalchemy_url(self) -> URL: database=conn.schema, ) + def get_sqlalchemy_engine(self, engine_kwargs=None): + """ + Get an sqlalchemy_engine object. + + :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`. + :return: the created engine. + """ + if engine_kwargs is None: + engine_kwargs = {} + engine_kwargs["creator"] = self.get_conn + + return super().get_sqlalchemy_engine(engine_kwargs) + def get_conn(self) -> jaydebeapi.Connection: conn: Connection = self.get_connection(self.get_conn_id()) host: str = conn.host diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py index cb38ce40ae391..f26a9d7ffb5b3 100644 --- a/tests/providers/jdbc/hooks/test_jdbc.py +++ b/tests/providers/jdbc/hooks/test_jdbc.py @@ -19,6 +19,7 @@ import json import logging +import sqlite3 from unittest import mock from unittest.mock import Mock, patch @@ -36,19 +37,30 @@ jdbc_conn_mock = Mock(name="jdbc_conn") -def get_hook(hook_params=None, conn_params=None): +def get_hook( + hook_params=None, + conn_params=None, + login: str | None = "login", + password: str | None = "password", + host: str | None = "host", + schema: str | None = "schema", + port: int | None = 1234, +): hook_params = hook_params or {} conn_params = conn_params or {} connection = Connection( **{ - **dict(login="login", password="password", host="host", schema="schema", port=1234), + **dict(login=login, password=password, host=host, schema=schema, port=port), **conn_params, } ) - hook = JdbcHook(**hook_params) - hook.get_connection = Mock() - hook.get_connection.return_value = connection + class MockedJdbcHook(JdbcHook): + @classmethod + def get_connection(cls, conn_id: str) -> Connection: + return connection + + hook = MockedJdbcHook(**hook_params) return hook @@ -201,3 +213,18 @@ def test_sqlalchemy_url_with_sqlalchemy_scheme(self): hook = get_hook(conn_params=conn_params, hook_params=hook_params) assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema" + + def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): + jdbc_hook = get_hook( + conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}), + login=None, + password=None, + host=None, + schema=":memory:", + port=None, + ) + + with sqlite3.connect(":memory:") as connection: + jdbc_hook.get_conn = lambda: connection + engine = jdbc_hook.get_sqlalchemy_engine() + assert engine.connect().connection.connection == connection