Skip to content

Commit 5f9e854

Browse files
committed
fix typedjson annotations
Signed-off-by: Clemens Vasters <clemens@vasters.com>
1 parent db30db8 commit 5f9e854

File tree

2 files changed

+49
-75
lines changed

2 files changed

+49
-75
lines changed

avrotize/avrotots.py

+40-71
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,9 @@ def convert_logical_type_to_typescript(self, avro_type: Dict) -> str:
6262

6363
def strip_nullable(self, ts_type: str) -> str:
6464
"""Strip nullable type from TypeScript type."""
65-
# Handle union types and strip nullable types
66-
types = [t.strip() for t in ts_type.split('|')]
67-
non_nullable_types = [t for t in types if t != 'null']
68-
return ' | '.join(non_nullable_types)
65+
if ts_type.endswith('?'):
66+
return ts_type[:-1]
67+
return ts_type
6968

7069
def is_typescript_primitive(self, ts_type: str) -> bool:
7170
"""Check if TypeScript type is a primitive."""
@@ -98,8 +97,8 @@ def convert_avro_type_to_typescript(self, avro_type: Union[str, Dict, List], par
9897
return '{ [key: string]: any }'
9998
if 'null' in avro_type:
10099
if len(avro_type) == 2:
101-
return f'{self.convert_avro_type_to_typescript([t for t in avro_type if t != "null"][0], parent_namespace, import_types, class_name, field_name)} | null'
102-
return f'{self.generate_embedded_union(class_name, field_name, avro_type, parent_namespace, import_types)} | null'
100+
return f'{self.convert_avro_type_to_typescript([t for t in avro_type if t != "null"][0], parent_namespace, import_types, class_name, field_name)}?'
101+
return f'{self.generate_embedded_union(class_name, field_name, avro_type, parent_namespace, import_types)}?'
103102
return self.generate_embedded_union(class_name, field_name, avro_type, parent_namespace, import_types)
104103
elif isinstance(avro_type, dict):
105104
if avro_type['type'] == 'record':
@@ -158,6 +157,7 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b
158157
'type_no_null': self.strip_nullable(field['definition']['type']),
159158
'is_primitive': field['definition']['is_primitive'],
160159
'is_enum': field['definition']['is_enum'],
160+
'is_array': field['definition']['is_array'],
161161
'docstring': field['docstring'],
162162
} for field in fields]
163163

@@ -235,14 +235,15 @@ def generate_field(self, field: Dict, parent_namespace: str, import_types: Set[s
235235
'name': field_name,
236236
'type': field_type,
237237
'is_primitive': self.is_typescript_primitive(field_type),
238+
'is_array': field_type.endswith('[]'),
238239
'is_enum': len(import_types_this) > 0 and self.is_enum_type(import_types_this.pop(),'')
239240
}
240241

241242
def get_is_json_match_clause(self, field_name: str, field_type: str, field_is_enum: bool) -> str:
242243
"""Generates the isJsonMatch clause for a field."""
243244
field_name_js = field_name.rstrip('_')
244-
is_optional = field_type.endswith(' | null')
245-
field_type = field_type.replace(' | null', '').strip()
245+
is_optional = field_type.endswith('?')
246+
field_type = self.strip_nullable(field_type)
246247

247248
if '|' in field_type:
248249
union_types = [t.strip() for t in field_type.split('|')]
@@ -361,75 +362,43 @@ def generate_embedded_union(self, class_name: str, field_name: str, avro_type: L
361362
return f"{namespace}.{union_class_name}"
362363

363364
def write_to_file(self, namespace: str, name: str, content: str):
364-
"""Write TypeScript class to file in the correct namespace directory."""
365-
directory_path = os.path.join(self.src_dir, *namespace.split('.'))
366-
if not os.path.exists(directory_path):
367-
os.makedirs(directory_path, exist_ok=True)
365+
"""Write TypeScript class to file in the correct namespace directory."""
366+
directory_path = os.path.join(self.src_dir, *namespace.split('.'))
367+
if not os.path.exists(directory_path):
368+
os.makedirs(directory_path, exist_ok=True)
368369

369-
file_path = os.path.join(directory_path, f"{name}.ts")
370-
with open(file_path, 'w', encoding='utf-8') as file:
371-
file.write(content)
370+
file_path = os.path.join(directory_path, f"{name}.ts")
371+
with open(file_path, 'w', encoding='utf-8') as file:
372+
file.write(content)
372373

373374
def generate_index_file(self):
374-
"""Generate index.ts files for each directory in the project."""
375-
# Define the tree node class
376-
class DirNode:
377-
def __init__(self):
378-
self.files = set() # Set of file names (without extensions)
379-
self.subdirs = {} # Mapping from subdir names to DirNode instances
380-
381-
# Build the directory tree
382-
root = DirNode()
375+
"""Generate a root index.ts file that exports all types with aliases scoped to their modules."""
376+
exports = []
383377

384378
for class_name in self.generated_types:
379+
# Split the class_name into parts
385380
parts = class_name.split('.')
386-
current_node = root
387-
388-
for idx, part in enumerate(parts):
389-
is_last = idx == len(parts) - 1
390-
if is_last:
391-
current_node.files.add(part)
392-
else:
393-
if part not in current_node.subdirs:
394-
current_node.subdirs[part] = DirNode()
395-
current_node = current_node.subdirs[part]
396-
397-
# Function to generate index.ts files recursively
398-
def generate_index_files(node, path_parts):
399-
"""Recursively generate index.ts files."""
400-
dir_path = os.path.join(self.src_dir, *path_parts)
401-
if not os.path.exists(dir_path):
402-
os.makedirs(dir_path, exist_ok=True)
403-
404-
exports = []
405-
406-
# Export all files in the current directory
407-
for file_name in sorted(node.files):
408-
export_path = f"./{file_name}.js"
409-
exports.append(f"export * from '{export_path}';\n")
410-
411-
for subdir_name, subdir_node in node.subdirs.items():
412-
# Check for name conflicts with files
413-
if subdir_name in node.files:
414-
# There is a file and a directory with the same name
415-
# Decide whether to export the subdirectory; here we skip it to avoid conflicts
416-
continue
417-
418-
# Export the subdirectory's index.js
419-
#export_path = f"./{subdir_name}/index.js"
420-
#exports.append(f"export * from '{export_path}';\n")
421-
422-
# Recursively generate index.ts for the subdirectory
423-
generate_index_files(subdir_node, path_parts + [subdir_name])
424-
425-
# Write the index.ts file
426-
index_file_path = os.path.join(dir_path, 'index.ts')
427-
with open(index_file_path, 'w', encoding='utf-8') as f:
428-
f.writelines(exports)
429-
430-
# Start the recursive generation from the root
431-
generate_index_files(root, [])
432-
381+
file_name = parts[-1] # The actual type name (e.g., 'FareRules')
382+
module_path = parts[:-1] # The module path excluding the type (e.g., ['gtfs_dash_data', 'GeneralTransitFeedStatic'])
383+
384+
# Construct the relative path to the .js file
385+
# Exclude 'gtfs_dash_data' from the module path for the file path
386+
file_relative_path = os.path.join(*(module_path[0:] + [f"{file_name}.js"])).replace(os.sep, '/')
387+
if not file_relative_path.startswith('.'):
388+
file_relative_path = './' + file_relative_path
389+
390+
# Construct the alias name by joining module parts with underscores
391+
# Exclude 'gtfs_dash_data' for brevity
392+
alias_parts = [pascal(part) for part in parts]
393+
alias_name = '_'.join(alias_parts)
394+
395+
# Generate the export statement with alias
396+
exports.append(f"export {{ {file_name} as {alias_name} }} from '{file_relative_path}';\n")
397+
398+
# Write the root index.ts file
399+
index_file_path = os.path.join(self.src_dir, 'index.ts')
400+
with open(index_file_path, 'w', encoding='utf-8') as f:
401+
f.writelines(exports)
433402

434403
def generate_project_files(self, output_dir: str):
435404
"""Generate project files using templates."""

avrotize/avrotots/class_core.ts.jinja

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
/** {{ class_name }} class. */
22
{%- if typed_json_annotation %}
33
import 'reflect-metadata';
4+
{%- if fields | selectattr("is_array") | list | length > 0 %}
5+
import { jsonObject, jsonMember, jsonArrayMember, TypedJSON } from 'typedjson';
6+
{%- else %}
47
import { jsonObject, jsonMember, TypedJSON } from 'typedjson';
58
{%- endif %}
9+
{%- endif %}
610
{%- if avro_annotation %}
711
import { Type } from 'avro-js';
812
{%- endif %}
@@ -22,12 +26,13 @@ export class {{ class_name }} {
2226
{%- for field in fields %}
2327
/** {{ field.docstring }} */
2428
{%- if typed_json_annotation %}
25-
@jsonMember
29+
{%- set field_type = field.type_no_null if not field.is_primitive else (field.type_no_null | pascal ) %}
30+
{% if field.is_array -%}@jsonArrayMember({{ field_type | replace('[]', '') }}) {%- else -%}@jsonMember({%-if not field.is_enum-%}{{ field_type }}{%-else-%}String{%-endif-%}){%- endif %}
2631
{%- endif %}
27-
public {{ field.name }}: {{ field.type }};
32+
public {{ field.name }}{%- if field.type.endswith('?')-%}?{%-endif-%} : {{ field.type_no_null }};
2833
{%- endfor %}
2934

30-
constructor({%- for field in fields %}{{ field.name }}: {{ field.type }}{%- if not loop.last %}, {%- endif %}{%- endfor %}) {
35+
constructor({%- for field in fields %}{{ field.name }}: {{ field.type_no_null }}{%- if not loop.last %}, {%- endif %}{%- endfor %}) {
3136
{%- for field in fields %}
3237
{%- if field.is_enum %}
3338
if ( typeof {{ field.name }} === 'number' ) {
@@ -104,7 +109,7 @@ export class {{ class_name }} {
104109
{%- else %}
105110
return (
106111
{%- for field in fields %}
107-
{{ get_is_json_match_clause(field.name, field.type, field.is_enum) }}{%- if not loop.last %} &&
112+
{{ get_is_json_match_clause(field.name, field.type_no_null, field.is_enum) }}{%- if not loop.last %} &&
108113
{%- endif %}
109114
{%- endfor %}
110115
);

0 commit comments

Comments
 (0)