Skip to content

Commit 36ad6ea

Browse files
committed
python primitive type updates
Signed-off-by: Clemens Vasters <clemens@vasters.com>
1 parent d757f45 commit 36ad6ea

File tree

5 files changed

+456
-21
lines changed

5 files changed

+456
-21
lines changed

avrotize/avrotopython.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,25 @@ def convert_logical_type_to_python(self, avro_type: Dict, import_types: Set[str]
114114
"""Converts Avro logical type to Python type"""
115115
if avro_type['logicalType'] == 'decimal':
116116
import_types.add('decimal.Decimal')
117-
return 'Decimal'
117+
return 'decimal.Decimal'
118118
elif avro_type['logicalType'] == 'date':
119119
import_types.add('datetime.date')
120-
return 'date'
120+
return 'datetime.date'
121121
elif avro_type['logicalType'] == 'time-millis':
122122
import_types.add('datetime.time')
123-
return 'time'
123+
return 'datetime.time'
124124
elif avro_type['logicalType'] == 'time-micros':
125125
import_types.add('datetime.time')
126-
return 'time'
126+
return 'datetime.time'
127127
elif avro_type['logicalType'] == 'timestamp-millis':
128128
import_types.add('datetime.datetime')
129-
return 'datetime'
129+
return 'datetime.datetime'
130130
elif avro_type['logicalType'] == 'timestamp-micros':
131131
import_types.add('datetime.datetime')
132-
return 'datetime'
132+
return 'datetime.datetime'
133133
elif avro_type['logicalType'] == 'duration':
134134
import_types.add('datetime.timedelta')
135-
return 'timedelta'
135+
return 'datetime.timedelta'
136136
return 'typing.Any'
137137

138138
def convert_avro_type_to_python(self, avro_type: Union[str, Dict, List], parent_package: str, import_types: set) -> str:
@@ -180,9 +180,9 @@ def init_field_value(self, field_type: str, field_name: str, field_is_enum: bool
180180
""" Initialize the field value based on its type. """
181181
if field_type == "typing.Any":
182182
return field_ref
183-
elif field_type in ['datetime', 'date', 'time', 'timedelta']:
183+
elif field_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta']:
184184
return f"{field_ref}"
185-
elif field_type in ['int', 'str', 'float', 'bool', 'bytes', 'Decimal', 'datetime', 'date', 'time', 'timedelta']:
185+
elif field_type in ['int', 'str', 'float', 'bool', 'bytes', 'Decimal']:
186186
return f"{field_type}({field_ref})"
187187
elif field_type.startswith("typing.List["):
188188
inner_type = get_typing_args_from_string(field_type)[0]
@@ -373,11 +373,11 @@ def generate_value(field_type: str):
373373
'float': f'float({random.uniform(0, 100)})',
374374
'bytes': 'b"test_bytes"',
375375
'None': 'None',
376-
'date': random.choice(['datetime.date.today()', 'datetime.date(2021, 1, 1)']),
377-
'datetime': 'datetime.datetime.now()',
378-
'time': 'datetime.datetime.now().time()',
379-
'Decimal': f'Decimal("{random.randint(0, 100)}.{random.randint(0, 100)}")',
380-
'timedelta': 'datetime.timedelta(days=1)',
376+
'datetime.date': random.choice(['datetime.date.today()', 'datetime.date(2021, 1, 1)']),
377+
'datetime.datetime': 'datetime.datetime.now(datetime.timezone.utc)',
378+
'datetime.time': 'datetime.datetime.now(datetime.timezone.utc).time()',
379+
'decimal.Decimal': f'decimal.Decimal("{random.randint(0, 100)}.{random.randint(0, 100)}")',
380+
'datetime.timedelta': 'datetime.timedelta(days=1)',
381381
'typing.Any': '{"test": "test"}'
382382
}
383383

avrotize/avrotopython/dataclass_core.jinja

+16-3
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,24 @@ import dataclasses
1212
{%- if dataclasses_json_annotation %}
1313
import dataclasses_json
1414
import json
15+
{%- for field in fields if field.type == "datetime" or field.type == "typing.Optional[datetime.datetime]" %}
16+
{%- if loop.first %}
17+
from marshmallow import fields
18+
{%- endif %}
19+
{%- endfor %}
1520
{%- endif %}
1621
{%- if avro_annotation %}
1722
import avro.schema
1823
import avro.io
1924
{%- endif %}
20-
{%- for import_type in import_types %}
25+
{%- for import_type in import_types if import_type not in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta', 'decimal.Decimal'] %}
2126
from {{ '.'.join(import_type.split('.')[:-1]) | lower }} import {{ import_type.split('.')[-1] }}
2227
{%- endfor %}
28+
{%- for import_type in import_types if import_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
29+
{%- if loop.first %}
30+
import datetime
31+
{%- endif %}
32+
{%- endfor %}
2333

2434
{% if dataclasses_json_annotation %}
2535
@dataclasses_json.dataclass_json
@@ -34,8 +44,9 @@ class {{ class_name }}:
3444
{%- endfor -%}
3545
"""
3646
{% for field in fields %}
37-
{{ field.name }}: {{ field.type }}=dataclasses.field(kw_only=True{% if dataclasses_json_annotation %}, metadata=dataclasses_json.config(field_name="{{ field.original_name }}"){%- endif %})
38-
{%- endfor %}
47+
{%- set isdate = field.type == "datetime" or field.type == "typing.Optional[datetime.datetime]" %}
48+
{{ field.name }}: {{ field.type }}=dataclasses.field(kw_only=True{% if dataclasses_json_annotation %}, metadata=dataclasses_json.config(field_name="{{ field.original_name }}"{%- if isdate -%}, encoder=lambda d: datetime.datetime.isoformat(d) if d else None, decoder=lambda d:datetime.datetime.fromisoformat(d) if d else None, mm_field=fields.DateTime(format='iso'){%- endif -%}){%- endif %})
49+
{%- endfor %}
3950
{% if avro_annotation %}
4051
AvroType: typing.ClassVar[avro.schema.Schema] = avro.schema.parse(
4152
"{{ avro_schema_json }}"
@@ -128,7 +139,9 @@ class {{ class_name }}:
128139

129140
{%- if dataclasses_json_annotation %}
130141
if content_type == 'application/json':
142+
#pylint: disable=no-member
131143
result = self.to_json()
144+
#pylint: enable=no-member
132145
{%- endif %}
133146

134147
if result is not None and content_type.endswith('+gzip'):

avrotize/avrotopython/test_class.jinja

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), '../src
1010

1111
from {{ package_name | lower }} import {{ class_name }}
1212

13-
{%- for import_type in import_types %}
13+
{%- for import_type in import_types if import_type not in ['decimal.Decimal', 'datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
1414
{%- set import_type_name = 'Test_'+import_type.split('.')[-1] %}
1515
{%- set import_package_name = 'test_'+'_'.join(import_type.split('.')[:-1]) | lower %}
1616

@@ -20,6 +20,12 @@ from .{{ import_package_name }} import {{ import_type_name }}
2020
from {{ import_package_name }} import {{ import_type_name }}
2121
{%- endif -%}
2222
{%- endfor %}
23+
{%- for import_type in import_types if import_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
24+
{%- if loop.first %}
25+
import datetime
26+
{%- endif %}
27+
{%- endfor %}
28+
2329

2430
class {{ test_class_name }}(unittest.TestCase):
2531
"""

0 commit comments

Comments
 (0)