Skip to content

Commit 42428eb

Browse files
committed
adding json-schema generator and python code generator, support typing.Self
1 parent a07ae93 commit 42428eb

22 files changed

+1173
-149
lines changed

tests/test_cls.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import utype
1111
from utype import (DataClass, Field, Options, Rule, Schema, exc,
1212
register_transformer, types)
13-
from utype.utils.compat import Final
13+
from utype.utils.compat import Final, Self
1414

1515

1616
@pytest.fixture(params=(False, True))
@@ -319,19 +319,28 @@ class T(Schema):
319319
T(forward_in_dict={1: [2], 2: [1]})
320320

321321
# test not-module-level self ref
322-
class Self(Schema):
322+
class SelfRef(Schema):
323323
name: str
324-
to_self: "Self" = Field(required=False)
325-
self_lst: List["Self"] = Field(default_factory=list)
324+
to_self: "SelfRef" = Field(required=False)
325+
self_lst: List["SelfRef"] = Field(default_factory=list)
326326

327-
sf = Self(name=1, to_self=b'{"name":"test"}')
327+
sf = SelfRef(name=1, to_self=b'{"name":"test"}')
328328
assert sf.to_self.name == "test"
329329
assert sf.self_lst == []
330330

331-
sf2 = Self(name="t2", self_lst=[dict(sf)])
331+
sf2 = SelfRef(name="t2", self_lst=[dict(sf)])
332332
assert sf2.self_lst[0].name == "1"
333333
assert "to_self" not in sf2
334334

335+
class SelfRef2(Schema):
336+
name: str
337+
to_self: Self = Field(required=False)
338+
self_lst: List[Self] = Field(default_factory=list)
339+
340+
sfi = SelfRef2(name=1, to_self=b'{"name":"test"}')
341+
assert sfi.to_self.name == "test"
342+
assert sfi.self_lst == []
343+
335344
# class ForwardSchema(Schema):
336345
# int1: 'types.PositiveInt' = Field(lt=10)
337346
# int2: 'types.PositiveInt' = Field(lt=20)
@@ -340,11 +349,11 @@ class Self(Schema):
340349

341350
def test_local_forward_ref(self):
342351
def f(u=0):
343-
class Self(Schema):
352+
class LocSelf(Schema):
344353
num: int = u
345-
to_self: Optional["Self"] = None
346-
list_self: List["Self"] = utype.Field(default_factory=list)
347-
data = Self(to_self={'to_self': {}}, list_self=[{'list_self': []}])
354+
to_self: Optional["LocSelf"] = None
355+
list_self: List["LocSelf"] = utype.Field(default_factory=list)
356+
data = LocSelf(to_self={'to_self': {}}, list_self=[{'list_self': []}])
348357
return data.to_self.to_self.num, data.list_self[0].num
349358

350359
assert f(1) == (1, 1)

tests/test_func.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import utype
1010
from utype import Field, Options, Param, exc, parse, types
11-
from utype.utils.compat import Final
11+
from utype.utils.compat import Final, Self
1212

1313

1414
@pytest.fixture(params=(False, True))
@@ -21,6 +21,22 @@ def on_error(request):
2121
return request.param
2222

2323

24+
class schemas:
25+
class MySchema(utype.Schema):
26+
a: int
27+
b: int
28+
result: int
29+
30+
@classmethod
31+
@utype.parse
32+
def add(cls, a: int, b: int) -> Self:
33+
return dict(
34+
a=a,
35+
b=b,
36+
result=a+b
37+
)
38+
39+
2440
class TestFunc:
2541
def test_basic(self):
2642
import utype
@@ -406,6 +422,13 @@ def fib(n: int = utype.Param(ge=0), _current: int = 0, _next: int = 1):
406422
assert fib('10', _current=10, _next=6) == 55
407423
assert fib('10', 10, 5) == 615 # can pass through positional args
408424

425+
def test_self_ref(self):
426+
result = schemas.MySchema.add('1', '2')
427+
assert isinstance(result, schemas.MySchema)
428+
assert result.a == 1
429+
assert result.b == 2
430+
assert result.result == 3
431+
409432
def test_args_parse(self):
410433
@utype.parse
411434
def get(a):

tests/test_type.py

+1
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def trans_my(trans, d, t):
313313
],
314314
date: [
315315
("2020-02-20", date(2020, 2, 20), True, True),
316+
("20200220", date(2020, 2, 20), True, True),
316317
("2020/02/20", date(2020, 2, 20), True, True),
317318
("2020/2/20", date(2020, 2, 20), True, True),
318319
("20/02/2020", date(2020, 2, 20), True, True),

utype/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
register_transformer = TypeTransformer.registry.register
1313

1414

15-
VERSION = (0, 5, 6, None)
15+
VERSION = (0, 6, 0, 'alpha')
1616

1717

1818
def _get_version():

utype/parser/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def __init__(self, obj, options: Options = None):
8787
def make_context(self, context=None, force_error: bool = False):
8888
return self.options.make_context(context=context, force_error=force_error)
8989

90+
@property
91+
def bound(self):
92+
return self.obj
93+
9094
@property
9195
def kwargs(self):
9296
return {}
@@ -109,7 +113,8 @@ def parse_annotation(self, annotation):
109113
annotation=annotation,
110114
forward_refs=self.forward_refs,
111115
global_vars=self.globals,
112-
force_clear_refs=self.is_local
116+
force_clear_refs=self.is_local,
117+
bound=self.bound
113118
)
114119

115120
@cached_property

utype/parser/cls.py

+2
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def generate_fields(self):
141141
forward_refs=self.forward_refs,
142142
options=self.options,
143143
force_clear_refs=self.is_local,
144+
bound=self.bound,
144145
**self.kwargs
145146
)
146147
except Exception as e:
@@ -185,6 +186,7 @@ def generate_fields(self):
185186
forward_refs=self.forward_refs,
186187
options=self.options,
187188
force_clear_refs=self.is_local,
189+
bound=self.bound,
188190
**self.kwargs
189191
)
190192
except Exception as e:

utype/parser/field.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from uuid import UUID
99

1010
from ..utils import exceptions as exc
11+
from ..utils.base import ParamsCollector
1112
from ..utils.compat import Literal, get_args, is_final, is_annotated, ForwardRef
1213
from ..utils.datastructures import unprovided
1314
from ..utils.functional import copy_value, get_name, multi
@@ -17,7 +18,7 @@
1718
represent = repr
1819

1920

20-
class Field:
21+
class Field(ParamsCollector):
2122
parser_field_cls = None
2223

2324
def __init__(
@@ -91,6 +92,8 @@ def __init__(
9192
min_contains: int = None,
9293
unique_items: Union[bool, ConstraintMode] = None,
9394
):
95+
super().__init__(locals())
96+
9497
if mode:
9598
if readonly or writeonly:
9699
raise exc.ConfigError(
@@ -1094,6 +1097,7 @@ def generate(
10941097
positional_only: bool = False,
10951098
global_vars=None,
10961099
forward_refs=None,
1100+
bound=None,
10971101
force_clear_refs=False,
10981102
**kwargs
10991103
):
@@ -1216,6 +1220,7 @@ def generate(
12161220
global_vars=global_vars,
12171221
forward_refs=forward_refs,
12181222
forward_key=attname,
1223+
bound=bound,
12191224
constraints=output_field.constraints if output_field else None,
12201225
force_clear_refs=force_clear_refs
12211226
)
@@ -1278,6 +1283,7 @@ def generate(
12781283
global_vars=global_vars,
12791284
forward_refs=forward_refs,
12801285
forward_key=attname,
1286+
bound=bound,
12811287
force_clear_refs=force_clear_refs
12821288
)
12831289

utype/parser/func.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ def _f_pass():
3333

3434

3535
class FunctionParser(BaseParser):
36+
@property
37+
def bound(self):
38+
# class A:
39+
# class B:
40+
# def f():
41+
# f.__qualname__ = 'A.B.f'
42+
# f.bound -> 'A.B'
43+
name = self.obj.__qualname__
44+
if '.' in name:
45+
return '.'.join(name.split('.')[:-1])
46+
return None
47+
3648
@classmethod
3749
def function_pass(cls, f):
3850
if not inspect.isfunction(f):
@@ -299,10 +311,12 @@ def generate_return_types(self):
299311
if not self.return_annotation:
300312
return
301313

302-
self.return_type = self.parse_annotation(annotation=self.return_annotation)
314+
self.return_type = self.parse_annotation(
315+
annotation=self.return_annotation
316+
)
303317

304318
# https://docs.python.org/3/library/typing.html#typing.Generator
305-
if self.return_type and issubclass(self.return_type, Rule):
319+
if self.return_type and isinstance(self.return_type, type) and issubclass(self.return_type, Rule):
306320
if self.is_generator:
307321
if self.return_type.__origin__ in (Iterable, Iterator):
308322
self.generator_yield_type = self.return_type.__args__[0]
@@ -406,6 +420,7 @@ def generate_fields(self):
406420
forward_refs=self.forward_refs,
407421
options=self.options,
408422
positional_only=param.kind == param.POSITIONAL_ONLY,
423+
bound=self.bound,
409424
**self.kwargs
410425
)
411426
except Exception as e:
@@ -760,6 +775,7 @@ def get_sync_generator(
760775
@wraps(self.obj)
761776
def eager_generator(*args, **kwargs) -> Generator:
762777
context = (options or self.options).make_context()
778+
self.resolve_forward_refs()
763779
args, kwargs = self.get_params(
764780
args,
765781
kwargs,
@@ -846,6 +862,7 @@ def get_async_generator(
846862
@wraps(self.obj)
847863
def eager_generator(*args, **kwargs) -> AsyncGenerator:
848864
context = (options or self.options).make_context()
865+
self.resolve_forward_refs()
849866
args, kwargs = self.get_params(
850867
args,
851868
kwargs,
@@ -886,6 +903,7 @@ def get_async_call(
886903
@wraps(self.obj)
887904
def eager_call(*args, **kwargs):
888905
context = (options or self.options).make_context()
906+
self.resolve_forward_refs()
889907
args, kwargs = self.get_params(
890908
args,
891909
kwargs,
@@ -915,6 +933,7 @@ def sync_call(
915933
parse_params: bool = None,
916934
parse_result: bool = None,
917935
):
936+
self.resolve_forward_refs()
918937
args, kwargs = self.get_params(
919938
args,
920939
kwargs,

utype/parser/options.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Callable, List, Optional, Set, Type, Union
44

55
from ..utils import exceptions as exc
6+
# from ..utils.base import ParamsCollector
67
from ..utils.compat import Literal
78
from ..utils.datastructures import unprovided
89
from ..utils.functional import multi
@@ -143,6 +144,7 @@ def __init__(
143144
# if this value is another callable (like dict, list), return value()
144145
# otherwise return this value directly when attr is unprovided
145146
):
147+
# super().__init__({k: v for k, v in locals().items() if not unprovided(v)})
146148

147149
if no_data_loss:
148150
if addition is None:
@@ -182,6 +184,8 @@ def __init__(
182184
for key, val in locals().items():
183185
if unprovided(val):
184186
continue
187+
if key.startswith('_'):
188+
continue
185189
if hasattr(self, key):
186190
# if getattr(self, key) == val:
187191
# continue

0 commit comments

Comments
 (0)