diff --git a/docs/contributing.md b/docs/contributing.md index ec1d475..6ffd5e8 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -18,7 +18,8 @@ $ python3 -m pip install ".[all]" $ git checkout -b new-branch ## 5. Run unittest (you should pass all test and coverage should be 100%) -$ ./scripts/test.sh +$ ./scripts/unittest.sh +$ ./scripts/integration_test.sh ## 6. Format code $ ./scripts/format.sh diff --git a/example.py b/example.py index 7a69ce7..7a665ab 100644 --- a/example.py +++ b/example.py @@ -118,7 +118,7 @@ def example_rollback_with_custom_exception(): class OriginalError(Exception): pass - with DataAPI(resource_arn=resource_arn, secret_arn=secret_arn, rollback_exception=rollback_exception=OriginalError) as data_api: + with DataAPI(resource_arn=resource_arn, secret_arn=secret_arn, rollback_exception=OriginalError) as data_api: data_api.execute(Insert(Pets, {'name': 'dog'})) # some logic ... diff --git a/pydataapi/pydataapi.py b/pydataapi/pydataapi.py index 506a9ec..8d7d7be 100644 --- a/pydataapi/pydataapi.py +++ b/pydataapi/pydataapi.py @@ -1,4 +1,6 @@ from contextlib import AbstractContextManager +from datetime import date, datetime, time +from decimal import Decimal from functools import wraps from typing import ( Any, @@ -47,6 +49,11 @@ DOUBLE_VALUES: str = 'doubleValues' BLOB_VALUES: str = 'blobValues' +DECIMAL_TYPE_HINT: str = 'DECIMAL' +TIMESTAMP_TYPE_HINT: str = 'TIMESTAMP' +TIME_TYPE_HINT: str = 'TIME' +DATE_TYPE_HINT: str = 'DATE' + def generate_sql(query: Union[Query, Insert, Update, Delete, Select]) -> str: if hasattr(query, 'statement'): @@ -138,33 +145,51 @@ def convert_array_value(value: Union[List, Tuple]) -> Dict[str, Any]: raise Exception(f'unsupported array type {type(value[0])}]: {value} ') -def convert_value(value: Any) -> Dict[str, Any]: +def create_sql_parameter(key: str, value: Any) -> Dict[str, Any]: + converted_value: Dict[str, Any] + type_hint: Optional[str] = None + if isinstance(value, bool): - return {BOOLEAN_VALUE: value} + converted_value = {BOOLEAN_VALUE: value} elif isinstance(value, str): - return {STRING_VALUE: value} + converted_value = {STRING_VALUE: value} elif isinstance(value, int): - return {LONG_VALUE: value} + converted_value = {LONG_VALUE: value} elif isinstance(value, float): - return {DOUBLE_VALUE: value} + converted_value = {DOUBLE_VALUE: value} elif isinstance(value, bytes): - return {BLOB_VALUE: value} + converted_value = {BLOB_VALUE: value} elif value is None: - return {IS_NULL: True} + converted_value = {IS_NULL: True} elif isinstance(value, (list, tuple)): - if not value: - return {IS_NULL: True} - return convert_array_value(value) - # TODO: support structValue - return {STRING_VALUE: str(value)} + if value: + converted_value = convert_array_value(value) + else: + converted_value = {IS_NULL: True} + elif isinstance(value, Decimal): + converted_value = {STRING_VALUE: str(value)} + type_hint = DECIMAL_TYPE_HINT + elif isinstance(value, datetime): + converted_value = {STRING_VALUE: value.strftime('%Y-%m-%d %H:%M:%S.%f')[:23]} + type_hint = TIMESTAMP_TYPE_HINT + elif isinstance(value, time): + converted_value = {STRING_VALUE: value.strftime('%H:%M:%S.%f')[:12]} + type_hint = TIME_TYPE_HINT + elif isinstance(value, date): + converted_value = {STRING_VALUE: value.strftime('%Y-%m-%d')} + type_hint = DATE_TYPE_HINT + else: + # TODO: support structValue + converted_value = {STRING_VALUE: str(value)} + if type_hint: + return {'name': key, 'value': converted_value, 'typeHint': type_hint} + return {'name': key, 'value': converted_value} def create_sql_parameters( parameter: Dict[str, Any] ) -> List[Dict[str, Union[str, Dict]]]: - return [ - {'name': key, 'value': convert_value(value)} for key, value in parameter.items() - ] + return [create_sql_parameter(key, value) for key, value in parameter.items()] def _get_value_from_row(row: Dict[str, Any]) -> Any: diff --git a/setup.cfg b/setup.cfg index 6bcfa1d..f1d35b3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ setup_requires = pytest-runner setuptools-scm install_requires = - boto3 == 1.11.15 + boto3 == 1.12.7 SQLAlchemy == 1.3.13 pydantic == 1.4 more-itertools == 8.0.2 diff --git a/tests/integration/test_mysql.py b/tests/integration/test_mysql.py index e196595..46f2a0a 100644 --- a/tests/integration/test_mysql.py +++ b/tests/integration/test_mysql.py @@ -1,11 +1,12 @@ import time +from datetime import datetime from typing import List import boto3 import pytest from pydataapi import DataAPI, Result, transaction from pydataapi.pydataapi import Record -from sqlalchemy import Column, Integer, String, create_engine +from sqlalchemy import Column, DateTime, Integer, String, create_engine from sqlalchemy.engine import Connection from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Query, sessionmaker @@ -18,6 +19,7 @@ class Pets(declarative_base()): __tablename__ = 'pets' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(255, collation='utf8_unicode_ci'), default=None) + seen_at = Column(DateTime, default=None) database: str = 'test' @@ -56,7 +58,7 @@ def db_connection(module_scoped_container_getter) -> Connection: def create_table(db_connection) -> None: db_connection.execute('drop table if exists pets;') db_connection.execute( - 'create table pets (id int auto_increment not null primary key, name varchar(10));' + 'create table pets (id int auto_increment not null primary key, name varchar(10), seen_at TIMESTAMP null);' ) @@ -117,32 +119,33 @@ def test_with_statement(rds_data_client, db_connection): query = Query(Pets).filter(Pets.id == 1) result = data_api.execute(query) - assert list(result) == [Record([1, 'dog'], [])] + assert list(result) == [Record([1, 'dog', None], [])] result = data_api.execute('select * from pets') - assert result.one().dict() == {'id': 1, 'name': 'dog'} - - insert: Insert = Insert(Pets) - data_api.batch_execute( - insert, - [ - {'id': 2, 'name': 'cat'}, - {'id': 3, 'name': 'snake'}, - {'id': 4, 'name': 'rabbit'}, - ], - ) - - result = data_api.execute('select * from pets') - expected = [ - Record([1, 'dog'], ['id', 'name']), - Record([2, 'cat'], ['id', 'name']), - Record([3, 'snake'], ['id', 'name']), - Record([4, 'rabbit'], ['id', 'name']), - ] - assert list(result) == expected - - for row, expected_row in zip(result, expected): - assert row == expected_row + assert result.one().dict() == {'id': 1, 'name': 'dog', 'seen_at': None} + + # This is deprecated. SQL Alchemy object will be no longer supported + # insert: Insert = Insert(Pets) + # data_api.batch_execute( + # insert, + # [ + # {'id': 2, 'name': 'cat', 'seen_at': None}, + # {'id': 3, 'name': 'snake', 'seen_at': None}, + # {'id': 4, 'name': 'rabbit', 'seen_at': None}, + # ], + # ) + # + # result = data_api.execute('select * from pets') + # expected = [ + # Record([1, 'dog', None], ['id', 'name', 'seen_at']), + # Record([2, 'cat', None], ['id', 'name', 'seen_at']), + # Record([3, 'snake', None], ['id', 'name', 'seen_at']), + # Record([4, 'rabbit', None], ['id', 'name', 'seen_at']), + # ] + # assert list(result) == expected + # + # for row, expected_row in zip(result, expected): + # assert row == expected_row def test_rollback(rds_data_client, db_connection): @@ -199,10 +202,10 @@ class OtherError(Exception): except: pass result = list(get_connection().execute('select * from pets')) - assert result == [(2, 'dog')] + assert result == [(2, 'dog', None)] -def test_dialect() -> None: +def test_dialect(create_table) -> None: rds_data_client = boto3.client( 'rds-data', endpoint_url='http://127.0.0.1:8080', @@ -222,3 +225,19 @@ def test_dialect() -> None: assert engine.has_table('foo') is False assert engine.has_table('pets') is True + + Session = sessionmaker() + Session.configure(bind=engine) + session = Session() + + dog = Pets(name="dog", seen_at=datetime(2020, 1, 2, 3, 4, 5, 6789)) + + session.add(dog) + session.commit() + + result = list(engine.execute('select * from pets')) + assert result[0] == ( + 1, + 'dog', + '2020-01-02 03:04:05', + ) # TODO Update local-data-api to support typeHint diff --git a/tests/pydataapi/test_pydataapi.py b/tests/pydataapi/test_pydataapi.py index 5dd2c39..3ea0c49 100644 --- a/tests/pydataapi/test_pydataapi.py +++ b/tests/pydataapi/test_pydataapi.py @@ -1,4 +1,5 @@ import datetime +from decimal import Decimal from typing import Any, Dict import pytest @@ -13,7 +14,7 @@ UpdateResults, _get_value_from_row, convert_array_value, - convert_value, + create_sql_parameter, create_sql_parameters, generate_sql, transaction, @@ -60,8 +61,8 @@ def mocked_client(mocker): ([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}), ], ) -def test_convert_value(input_value: Any, expected: Dict[str, Any]) -> None: - assert convert_value(input_value) == expected +def test_create_sql_parameter(input_value: Any, expected: Dict[str, Any]) -> None: + assert create_sql_parameter('', input_value)['value'] == expected def test_convert_value_other_types() -> None: @@ -69,10 +70,32 @@ class Dummy: def __str__(self): return 'Dummy' - assert convert_value(Dummy()) == {'stringValue': 'Dummy'} + assert create_sql_parameter('', Dummy())['value'] == {'stringValue': 'Dummy'} - assert convert_value(datetime.datetime(2020, 1, 1)) == { - 'stringValue': '2020-01-01 00:00:00' + assert create_sql_parameter('decimal', Decimal(123456789)) == { + 'name': 'decimal', + 'typeHint': 'DECIMAL', + 'value': {'stringValue': '123456789'}, + } + + assert create_sql_parameter( + 'datetime', datetime.datetime(2020, 1, 2, 3, 4, 5, 678900) + ) == { + 'name': 'datetime', + 'typeHint': 'TIMESTAMP', + 'value': {'stringValue': '2020-01-02 03:04:05.678'}, + } + + assert create_sql_parameter('date', datetime.date(2020, 1, 2)) == { + 'name': 'date', + 'typeHint': 'DATE', + 'value': {'stringValue': '2020-01-02'}, + } + + assert create_sql_parameter('time', datetime.time(3, 4, 5, 678900)) == { + 'name': 'time', + 'typeHint': 'TIME', + 'value': {'stringValue': '03:04:05.678'}, }