Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support typehint #57

Merged
merged 2 commits into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ $ python3 -m pip install ".[all]"
$ git checkout -b new-branch

## 5. Run unittest (you should pass all test and coverage should be 100%)
$ ./scripts/test.sh
$ ./scripts/unittest.sh
$ ./scripts/integration_test.sh

## 6. Format code
$ ./scripts/format.sh
Expand Down
2 changes: 1 addition & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def example_rollback_with_custom_exception():
class OriginalError(Exception):
pass

with DataAPI(resource_arn=resource_arn, secret_arn=secret_arn, rollback_exception=rollback_exception=OriginalError) as data_api:
with DataAPI(resource_arn=resource_arn, secret_arn=secret_arn, rollback_exception=OriginalError) as data_api:
data_api.execute(Insert(Pets, {'name': 'dog'}))
# some logic ...

Expand Down
55 changes: 40 additions & 15 deletions pydataapi/pydataapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from contextlib import AbstractContextManager
from datetime import date, datetime, time
from decimal import Decimal
from functools import wraps
from typing import (
Any,
Expand Down Expand Up @@ -47,6 +49,11 @@
DOUBLE_VALUES: str = 'doubleValues'
BLOB_VALUES: str = 'blobValues'

DECIMAL_TYPE_HINT: str = 'DECIMAL'
TIMESTAMP_TYPE_HINT: str = 'TIMESTAMP'
TIME_TYPE_HINT: str = 'TIME'
DATE_TYPE_HINT: str = 'DATE'


def generate_sql(query: Union[Query, Insert, Update, Delete, Select]) -> str:
if hasattr(query, 'statement'):
Expand Down Expand Up @@ -138,33 +145,51 @@ def convert_array_value(value: Union[List, Tuple]) -> Dict[str, Any]:
raise Exception(f'unsupported array type {type(value[0])}]: {value} ')


def convert_value(value: Any) -> Dict[str, Any]:
def create_sql_parameter(key: str, value: Any) -> Dict[str, Any]:
converted_value: Dict[str, Any]
type_hint: Optional[str] = None

if isinstance(value, bool):
return {BOOLEAN_VALUE: value}
converted_value = {BOOLEAN_VALUE: value}
elif isinstance(value, str):
return {STRING_VALUE: value}
converted_value = {STRING_VALUE: value}
elif isinstance(value, int):
return {LONG_VALUE: value}
converted_value = {LONG_VALUE: value}
elif isinstance(value, float):
return {DOUBLE_VALUE: value}
converted_value = {DOUBLE_VALUE: value}
elif isinstance(value, bytes):
return {BLOB_VALUE: value}
converted_value = {BLOB_VALUE: value}
elif value is None:
return {IS_NULL: True}
converted_value = {IS_NULL: True}
elif isinstance(value, (list, tuple)):
if not value:
return {IS_NULL: True}
return convert_array_value(value)
# TODO: support structValue
return {STRING_VALUE: str(value)}
if value:
converted_value = convert_array_value(value)
else:
converted_value = {IS_NULL: True}
elif isinstance(value, Decimal):
converted_value = {STRING_VALUE: str(value)}
type_hint = DECIMAL_TYPE_HINT
elif isinstance(value, datetime):
converted_value = {STRING_VALUE: value.strftime('%Y-%m-%d %H:%M:%S.%f')[:23]}
type_hint = TIMESTAMP_TYPE_HINT
elif isinstance(value, time):
converted_value = {STRING_VALUE: value.strftime('%H:%M:%S.%f')[:12]}
type_hint = TIME_TYPE_HINT
elif isinstance(value, date):
converted_value = {STRING_VALUE: value.strftime('%Y-%m-%d')}
type_hint = DATE_TYPE_HINT
else:
# TODO: support structValue
converted_value = {STRING_VALUE: str(value)}
if type_hint:
return {'name': key, 'value': converted_value, 'typeHint': type_hint}
return {'name': key, 'value': converted_value}


def create_sql_parameters(
parameter: Dict[str, Any]
) -> List[Dict[str, Union[str, Dict]]]:
return [
{'name': key, 'value': convert_value(value)} for key, value in parameter.items()
]
return [create_sql_parameter(key, value) for key, value in parameter.items()]


def _get_value_from_row(row: Dict[str, Any]) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ setup_requires =
pytest-runner
setuptools-scm
install_requires =
boto3 == 1.11.15
boto3 == 1.12.7
SQLAlchemy == 1.3.13
pydantic == 1.4
more-itertools == 8.0.2
Expand Down
75 changes: 47 additions & 28 deletions tests/integration/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import time
from datetime import datetime
from typing import List

import boto3
import pytest
from pydataapi import DataAPI, Result, transaction
from pydataapi.pydataapi import Record
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy import Column, DateTime, Integer, String, create_engine
from sqlalchemy.engine import Connection
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query, sessionmaker
Expand All @@ -18,6 +19,7 @@ class Pets(declarative_base()):
__tablename__ = 'pets'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(255, collation='utf8_unicode_ci'), default=None)
seen_at = Column(DateTime, default=None)


database: str = 'test'
Expand Down Expand Up @@ -56,7 +58,7 @@ def db_connection(module_scoped_container_getter) -> Connection:
def create_table(db_connection) -> None:
db_connection.execute('drop table if exists pets;')
db_connection.execute(
'create table pets (id int auto_increment not null primary key, name varchar(10));'
'create table pets (id int auto_increment not null primary key, name varchar(10), seen_at TIMESTAMP null);'
)


Expand Down Expand Up @@ -117,32 +119,33 @@ def test_with_statement(rds_data_client, db_connection):
query = Query(Pets).filter(Pets.id == 1)
result = data_api.execute(query)

assert list(result) == [Record([1, 'dog'], [])]
assert list(result) == [Record([1, 'dog', None], [])]

result = data_api.execute('select * from pets')
assert result.one().dict() == {'id': 1, 'name': 'dog'}

insert: Insert = Insert(Pets)
data_api.batch_execute(
insert,
[
{'id': 2, 'name': 'cat'},
{'id': 3, 'name': 'snake'},
{'id': 4, 'name': 'rabbit'},
],
)

result = data_api.execute('select * from pets')
expected = [
Record([1, 'dog'], ['id', 'name']),
Record([2, 'cat'], ['id', 'name']),
Record([3, 'snake'], ['id', 'name']),
Record([4, 'rabbit'], ['id', 'name']),
]
assert list(result) == expected

for row, expected_row in zip(result, expected):
assert row == expected_row
assert result.one().dict() == {'id': 1, 'name': 'dog', 'seen_at': None}

# This is deprecated. SQL Alchemy object will be no longer supported
# insert: Insert = Insert(Pets)
# data_api.batch_execute(
# insert,
# [
# {'id': 2, 'name': 'cat', 'seen_at': None},
# {'id': 3, 'name': 'snake', 'seen_at': None},
# {'id': 4, 'name': 'rabbit', 'seen_at': None},
# ],
# )
#
# result = data_api.execute('select * from pets')
# expected = [
# Record([1, 'dog', None], ['id', 'name', 'seen_at']),
# Record([2, 'cat', None], ['id', 'name', 'seen_at']),
# Record([3, 'snake', None], ['id', 'name', 'seen_at']),
# Record([4, 'rabbit', None], ['id', 'name', 'seen_at']),
# ]
# assert list(result) == expected
#
# for row, expected_row in zip(result, expected):
# assert row == expected_row


def test_rollback(rds_data_client, db_connection):
Expand Down Expand Up @@ -199,10 +202,10 @@ class OtherError(Exception):
except:
pass
result = list(get_connection().execute('select * from pets'))
assert result == [(2, 'dog')]
assert result == [(2, 'dog', None)]


def test_dialect() -> None:
def test_dialect(create_table) -> None:
rds_data_client = boto3.client(
'rds-data',
endpoint_url='http://127.0.0.1:8080',
Expand All @@ -222,3 +225,19 @@ def test_dialect() -> None:

assert engine.has_table('foo') is False
assert engine.has_table('pets') is True

Session = sessionmaker()
Session.configure(bind=engine)
session = Session()

dog = Pets(name="dog", seen_at=datetime(2020, 1, 2, 3, 4, 5, 6789))

session.add(dog)
session.commit()

result = list(engine.execute('select * from pets'))
assert result[0] == (
1,
'dog',
'2020-01-02 03:04:05',
) # TODO Update local-data-api to support typeHint
35 changes: 29 additions & 6 deletions tests/pydataapi/test_pydataapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from decimal import Decimal
from typing import Any, Dict

import pytest
Expand All @@ -13,7 +14,7 @@
UpdateResults,
_get_value_from_row,
convert_array_value,
convert_value,
create_sql_parameter,
create_sql_parameters,
generate_sql,
transaction,
Expand Down Expand Up @@ -60,19 +61,41 @@ def mocked_client(mocker):
([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}),
],
)
def test_convert_value(input_value: Any, expected: Dict[str, Any]) -> None:
assert convert_value(input_value) == expected
def test_create_sql_parameter(input_value: Any, expected: Dict[str, Any]) -> None:
assert create_sql_parameter('', input_value)['value'] == expected


def test_convert_value_other_types() -> None:
class Dummy:
def __str__(self):
return 'Dummy'

assert convert_value(Dummy()) == {'stringValue': 'Dummy'}
assert create_sql_parameter('', Dummy())['value'] == {'stringValue': 'Dummy'}

assert convert_value(datetime.datetime(2020, 1, 1)) == {
'stringValue': '2020-01-01 00:00:00'
assert create_sql_parameter('decimal', Decimal(123456789)) == {
'name': 'decimal',
'typeHint': 'DECIMAL',
'value': {'stringValue': '123456789'},
}

assert create_sql_parameter(
'datetime', datetime.datetime(2020, 1, 2, 3, 4, 5, 678900)
) == {
'name': 'datetime',
'typeHint': 'TIMESTAMP',
'value': {'stringValue': '2020-01-02 03:04:05.678'},
}

assert create_sql_parameter('date', datetime.date(2020, 1, 2)) == {
'name': 'date',
'typeHint': 'DATE',
'value': {'stringValue': '2020-01-02'},
}

assert create_sql_parameter('time', datetime.time(3, 4, 5, 678900)) == {
'name': 'time',
'typeHint': 'TIME',
'value': {'stringValue': '03:04:05.678'},
}


Expand Down