Skip to content

Commit

Permalink
fix default_paramstyle (#27)
Browse files Browse the repository at this point in the history
* fix default_paramstyle

* fix validator

* support lastrowid

* support null value

* fix type hint
  • Loading branch information
koxudaxi authored Oct 16, 2019
1 parent b2ea173 commit 210c1e9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
9 changes: 9 additions & 0 deletions pydataapi/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,16 @@ def __init__(self, data_api: DataAPI) -> None:

self._rows: List[List] = []
self._rowcount: int = -1
self._lastrowid: Optional[int] = None

@property
def rowcount(self) -> int:
return self._rowcount

@property
def lastrowid(self) -> Optional[int]:
return self._lastrowid

def close(self) -> None:
self.closed = True

Expand All @@ -131,6 +136,7 @@ def execute(
rows: List[List] = getattr(result, '_rows')
self._rows = rows
self._rowcount = len(rows) or result.number_of_records_updated
self._lastrowid = result.generated_fields_first # type: ignore
return self

def executemany(
Expand All @@ -141,6 +147,9 @@ def executemany(
self._rows = [result.generated_fields for result in results]
self._rowcount = len(self._rows)
self.description = []
self._lastrowid = ( # type: ignore
results[-1].generated_fields_first if results else None # type: ignore
)
return self

def fetchone(self) -> Optional[List]:
Expand Down
13 changes: 2 additions & 11 deletions pydataapi/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class DataAPIDialect(DefaultDialect, ABC):

supports_comments = True
inline_comments = True
default_paramstyle = "named"

cte_follows_insert = True

Expand Down Expand Up @@ -168,21 +167,13 @@ def _detect_charset(self, connection: Any) -> Any: # pragma: no cover
pass

name = "mysql"
statement_compiler = MySQLCompiler
ddl_compiler = MySQLDDLCompiler
type_compiler = MySQLTypeCompiler

preparer = MySQLIdentifierPreparer
default_paramstyle = "named"


class PostgreSQLDataAPIDialect(PGDialect, DataAPIDialect):
name = "postgresql"
default_paramstyle = "named"
supports_alter = True
max_identifier_length = 63
supports_sane_rowcount = True
statement_compiler = PGCompiler
ddl_compiler = PGDDLCompiler
type_compiler = PGTypeCompiler
preparer = PGIdentifierPreparer
inspector = PGInspector
isolation_level = None
11 changes: 8 additions & 3 deletions pydataapi/pydataapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ def __len__(self) -> int:

def __init__(self, response: Dict):
self._response = response
self._rows: Sequence[List[Dict]] = [
[tuple(column.values())[0] for column in row]
self._rows: Sequence[List] = [
[
None
if tuple(column.keys())[0] == 'isNull'
else tuple(column.values())[0]
for column in row
]
for row in response.get('records', []) # type: ignore
]
self._column_metadata: List[Dict[str, Any]] = response.get('columnMetadata', [])
Expand Down Expand Up @@ -251,7 +256,7 @@ def convert_parameters(cls, v: Any) -> Any:

@validator('parameterSets', pre=True)
def convert_parameter_sets(cls, v: Any) -> Any:
if isinstance(v, list):
if isinstance(v, (list, tuple)):
return [create_sql_parameters(parameter) for parameter in v]
return v

Expand Down
4 changes: 3 additions & 1 deletion tests/pydataapi/test_dbaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ def test_rollback_not_called(mocked_client) -> None:
def test_execute_insert(mocked_client, mocker) -> None:
mocked_client.begin_transaction.return_value = {'transactionId': 'abc'}
mocked_client.execute_statement.return_value = {
'generatedFields': [],
'generatedFields': [{'longValue': 3}],
'numberOfRecordsUpdated': 1,
}
data_api = connect(
resource_arn='dummy', secret_arn='dummy', database='test', client=mocked_client
)
results = data_api.execute("insert into pets values(1, 'cat')")
assert list(results.fetchall()) == []
assert results.lastrowid == 3
assert mocked_client.execute_statement.call_args == mocker.call(
continueAfterTimeout=True,
includeResultMetadata=True,
Expand Down Expand Up @@ -215,6 +216,7 @@ def test_execute_insert_parameter_set(mocked_client, mocker) -> None:
rows = results.fetchall()
assert len(rows) == 2
assert rows == [[3], [4]]
assert results.lastrowid == 4

assert mocked_client.batch_execute_statement.call_args == mocker.call(
resourceArn='dummy',
Expand Down

0 comments on commit 210c1e9

Please sign in to comment.