diff --git a/example.py b/example.py index 2ec0f5a..632b34b 100644 --- a/example.py +++ b/example.py @@ -131,7 +131,7 @@ class OriginalError(Exception): def example_driver_for_sqlalchemy(): from sqlalchemy.engine import create_engine engine = create_engine( - 'mysql+pydataapi://', + 'mysql+pydataapi://', # or 'postgresql+pydataapi://', connect_args={ 'resource_arn': 'arn:aws:rds:us-east-1:123456789012:cluster:dummy', 'secret_arn': 'arn:aws:secretsmanager:us-east-1:123456789012:secret:dummy', diff --git a/pydataapi/dbapi.py b/pydataapi/dbapi.py index 344c40e..d945920 100644 --- a/pydataapi/dbapi.py +++ b/pydataapi/dbapi.py @@ -68,6 +68,8 @@ def rollback(self) -> None: self._data_api.rollback() def cursor(self) -> 'Cursor': + if not self._data_api.transaction_id: + self._data_api.begin() cursor = Cursor(self._data_api) self.cursors.append(cursor) diff --git a/scripts/fix_format.sh b/scripts/fix_format.sh deleted file mode 100755 index 71a772d..0000000 --- a/scripts/fix_format.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash -set -e - -black pydataapi tests --skip-string-normalization -isort --recursive -w 88 --combine-as --thirdparty pydataapi pydataapi tests -m 3 -tc diff --git a/tests/pydataapi/test_dbaapi.py b/tests/pydataapi/test_dbaapi.py index 02436b5..9e7d151 100644 --- a/tests/pydataapi/test_dbaapi.py +++ b/tests/pydataapi/test_dbaapi.py @@ -56,6 +56,7 @@ def test_rollback_not_called(mocked_client) -> None: def test_execute_insert(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.return_value = { 'generatedFields': [], 'numberOfRecordsUpdated': 1, @@ -72,10 +73,12 @@ def test_execute_insert(mocked_client, mocker) -> None: secretArn='dummy', sql="insert into pets values(1, 'cat')", database='test', + transactionId='abc', ) def test_execute_insert_parameters(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.return_value = { 'generatedFields': [], 'numberOfRecordsUpdated': 1, @@ -99,10 +102,12 @@ def test_execute_insert_parameters(mocked_client, mocker) -> None: secretArn='dummy', sql="insert into pets values(:id, :name)", database='test', + transactionId='abc', ) def test_execute_select(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.return_value = { 'numberOfRecordsUpdated': 0, 'records': [[{'longValue': 1}, {'stringValue': 'cat'}]], @@ -121,6 +126,7 @@ def test_execute_select(mocked_client, mocker) -> None: resourceArn='dummy', secretArn='dummy', sql='select * from pets', + transactionId='abc', ) data_api.close() @@ -128,6 +134,7 @@ def test_execute_select(mocked_client, mocker) -> None: def test_execute_select_fetch_many(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.return_value = { 'numberOfRecordsUpdated': 0, 'records': [ @@ -150,6 +157,7 @@ def test_execute_select_fetch_many(mocked_client, mocker) -> None: resourceArn='dummy', secretArn='dummy', sql='select * from pets', + transactionId='abc', ) data_api.close() @@ -157,6 +165,7 @@ def test_execute_select_fetch_many(mocked_client, mocker) -> None: def test_execute_select_iter(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.return_value = { 'numberOfRecordsUpdated': 0, 'records': [ @@ -180,6 +189,7 @@ def test_execute_select_iter(mocked_client, mocker) -> None: resourceArn='dummy', secretArn='dummy', sql='select * from pets', + transactionId='abc', ) data_api.close() @@ -187,6 +197,7 @@ def test_execute_select_iter(mocked_client, mocker) -> None: def test_execute_insert_parameter_set(mocked_client, mocker) -> None: + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.batch_execute_statement.return_value = { 'updateResults': [ {'generatedFields': [{'longValue': 3}]}, @@ -220,6 +231,7 @@ def test_execute_insert_parameter_set(mocked_client, mocker) -> None: ], ], database='test', + transactionId='abc', ) diff --git a/tests/pydataapi/test_dialect.py b/tests/pydataapi/test_dialect.py index a43c608..3492fd7 100644 --- a/tests/pydataapi/test_dialect.py +++ b/tests/pydataapi/test_dialect.py @@ -10,6 +10,7 @@ def mocked_client(mocker): def test_mysql(mocked_client) -> None: from sqlalchemy.engine import create_engine + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} mocked_client.execute_statement.side_effect = [ {'records': [[{'stringValue': 'test plain returns'}]]}, {'records': [[{'stringValue': 'test unicode returns'}]]}, @@ -54,6 +55,7 @@ def test_mysql(mocked_client) -> None: ] engine = create_engine( 'mysql+pydataapi://', + echo=True, connect_args={ 'resource_arn': 'arn:aws:rds:us-east-1:123456789012:cluster:dummy', 'secret_arn': 'arn:aws:secretsmanager:us-east-1:123456789012:secret:dummy', @@ -69,6 +71,8 @@ def test_mysql(mocked_client) -> None: def test_postgresql(mocked_client) -> None: from sqlalchemy.engine import create_engine + mocked_client.begin_transaction.return_value = {'transactionId': 'abc'} + mocked_client.execute_statement.side_effect = [ {'records': [[{'stringValue': 'test plain returns'}]]}, {'records': [[{'stringValue': 'test unicode returns'}]]},