Skip to content

Commit 21b024a

Browse files
fantixtailhook
andauthored
Add optional default to generated params (#426)
Co-authored-by: Paul Colomiets <paul@colomiets.name>
1 parent bb7522c commit 21b024a

File tree

5 files changed

+92
-80
lines changed

5 files changed

+92
-80
lines changed

edgedb/codegen/generator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _generate(
325325
kw_only = True
326326
for el_name, el in dr.input_type.elements.items():
327327
args[el_name] = self._generate_code_with_cardinality(
328-
el.type, el_name, el.cardinality
328+
el.type, el_name, el.cardinality, keyword_argument=True
329329
)
330330

331331
if self._async:
@@ -502,6 +502,7 @@ def _generate_code_with_cardinality(
502502
type_: typing.Optional[describe.AnyType],
503503
name_hint: str,
504504
cardinality: edgedb.Cardinality,
505+
keyword_argument: bool = False,
505506
):
506507
rv = self._generate_code(type_, name_hint)
507508
if cardinality == edgedb.Cardinality.AT_MOST_ONE:
@@ -510,6 +511,8 @@ def _generate_code_with_cardinality(
510511
else:
511512
self._imports.add("typing")
512513
rv = f"typing.Optional[{rv}]"
514+
if keyword_argument:
515+
rv = f"{rv} = None"
513516
return rv
514517

515518
def _find_name(self, name: str) -> str:

tests/codegen/test-project2/generated_async_edgeql.py.assert

+25-25
Original file line numberDiff line numberDiff line change
@@ -165,55 +165,55 @@ async def my_query(
165165
executor: edgedb.AsyncIOExecutor,
166166
*,
167167
a: uuid.UUID,
168-
b: uuid.UUID | None,
168+
b: uuid.UUID | None = None,
169169
c: str,
170-
d: str | None,
170+
d: str | None = None,
171171
e: bytes,
172-
f: bytes | None,
172+
f: bytes | None = None,
173173
g: int,
174-
h: int | None,
174+
h: int | None = None,
175175
i: int,
176-
j: int | None,
176+
j: int | None = None,
177177
k: int,
178-
l: int | None,
178+
l: int | None = None,
179179
m: float,
180-
n: float | None,
180+
n: float | None = None,
181181
o: float,
182-
p: float | None,
182+
p: float | None = None,
183183
q: bool,
184-
r: bool | None,
184+
r: bool | None = None,
185185
s: datetime.datetime,
186-
t: datetime.datetime | None,
186+
t: datetime.datetime | None = None,
187187
u: datetime.datetime,
188-
v: datetime.datetime | None,
188+
v: datetime.datetime | None = None,
189189
w: datetime.date,
190-
x: datetime.date | None,
190+
x: datetime.date | None = None,
191191
y: datetime.time,
192-
z: datetime.time | None,
192+
z: datetime.time | None = None,
193193
aa: datetime.timedelta,
194-
ab: datetime.timedelta | None,
194+
ab: datetime.timedelta | None = None,
195195
ac: int,
196-
ad: int | None,
196+
ad: int | None = None,
197197
ae: edgedb.RelativeDuration,
198-
af: edgedb.RelativeDuration | None,
198+
af: edgedb.RelativeDuration | None = None,
199199
ag: edgedb.DateDuration,
200-
ah: edgedb.DateDuration | None,
200+
ah: edgedb.DateDuration | None = None,
201201
ai: edgedb.ConfigMemory,
202-
aj: edgedb.ConfigMemory | None,
202+
aj: edgedb.ConfigMemory | None = None,
203203
ak: edgedb.Range[int],
204-
al: edgedb.Range[int] | None,
204+
al: edgedb.Range[int] | None = None,
205205
am: edgedb.Range[int],
206-
an: edgedb.Range[int] | None,
206+
an: edgedb.Range[int] | None = None,
207207
ao: edgedb.Range[float],
208-
ap: edgedb.Range[float] | None,
208+
ap: edgedb.Range[float] | None = None,
209209
aq: edgedb.Range[float],
210-
ar: edgedb.Range[float] | None,
210+
ar: edgedb.Range[float] | None = None,
211211
as_: edgedb.Range[datetime.datetime],
212-
at: edgedb.Range[datetime.datetime] | None,
212+
at: edgedb.Range[datetime.datetime] | None = None,
213213
au: edgedb.Range[datetime.datetime],
214-
av: edgedb.Range[datetime.datetime] | None,
214+
av: edgedb.Range[datetime.datetime] | None = None,
215215
aw: edgedb.Range[datetime.date],
216-
ax: edgedb.Range[datetime.date] | None,
216+
ax: edgedb.Range[datetime.date] | None = None,
217217
) -> MyQueryResult:
218218
return await executor.query_single(
219219
"""\

tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert

+25-25
Original file line numberDiff line numberDiff line change
@@ -91,55 +91,55 @@ async def my_query(
9191
executor: edgedb.AsyncIOExecutor,
9292
*,
9393
a: uuid.UUID,
94-
b: typing.Optional[uuid.UUID],
94+
b: typing.Optional[uuid.UUID] = None,
9595
c: str,
96-
d: typing.Optional[str],
96+
d: typing.Optional[str] = None,
9797
e: bytes,
98-
f: typing.Optional[bytes],
98+
f: typing.Optional[bytes] = None,
9999
g: int,
100-
h: typing.Optional[int],
100+
h: typing.Optional[int] = None,
101101
i: int,
102-
j: typing.Optional[int],
102+
j: typing.Optional[int] = None,
103103
k: int,
104-
l: typing.Optional[int],
104+
l: typing.Optional[int] = None,
105105
m: float,
106-
n: typing.Optional[float],
106+
n: typing.Optional[float] = None,
107107
o: float,
108-
p: typing.Optional[float],
108+
p: typing.Optional[float] = None,
109109
q: bool,
110-
r: typing.Optional[bool],
110+
r: typing.Optional[bool] = None,
111111
s: datetime.datetime,
112-
t: typing.Optional[datetime.datetime],
112+
t: typing.Optional[datetime.datetime] = None,
113113
u: datetime.datetime,
114-
v: typing.Optional[datetime.datetime],
114+
v: typing.Optional[datetime.datetime] = None,
115115
w: datetime.date,
116-
x: typing.Optional[datetime.date],
116+
x: typing.Optional[datetime.date] = None,
117117
y: datetime.time,
118-
z: typing.Optional[datetime.time],
118+
z: typing.Optional[datetime.time] = None,
119119
aa: datetime.timedelta,
120-
ab: typing.Optional[datetime.timedelta],
120+
ab: typing.Optional[datetime.timedelta] = None,
121121
ac: int,
122-
ad: typing.Optional[int],
122+
ad: typing.Optional[int] = None,
123123
ae: edgedb.RelativeDuration,
124-
af: typing.Optional[edgedb.RelativeDuration],
124+
af: typing.Optional[edgedb.RelativeDuration] = None,
125125
ag: edgedb.DateDuration,
126-
ah: typing.Optional[edgedb.DateDuration],
126+
ah: typing.Optional[edgedb.DateDuration] = None,
127127
ai: edgedb.ConfigMemory,
128-
aj: typing.Optional[edgedb.ConfigMemory],
128+
aj: typing.Optional[edgedb.ConfigMemory] = None,
129129
ak: edgedb.Range[int],
130-
al: typing.Optional[edgedb.Range[int]],
130+
al: typing.Optional[edgedb.Range[int]] = None,
131131
am: edgedb.Range[int],
132-
an: typing.Optional[edgedb.Range[int]],
132+
an: typing.Optional[edgedb.Range[int]] = None,
133133
ao: edgedb.Range[float],
134-
ap: typing.Optional[edgedb.Range[float]],
134+
ap: typing.Optional[edgedb.Range[float]] = None,
135135
aq: edgedb.Range[float],
136-
ar: typing.Optional[edgedb.Range[float]],
136+
ar: typing.Optional[edgedb.Range[float]] = None,
137137
as_: edgedb.Range[datetime.datetime],
138-
at: typing.Optional[edgedb.Range[datetime.datetime]],
138+
at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
139139
au: edgedb.Range[datetime.datetime],
140-
av: typing.Optional[edgedb.Range[datetime.datetime]],
140+
av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
141141
aw: edgedb.Range[datetime.date],
142-
ax: typing.Optional[edgedb.Range[datetime.date]],
142+
ax: typing.Optional[edgedb.Range[datetime.date]] = None,
143143
) -> MyQueryResult:
144144
return await executor.query_single(
145145
"""\

tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert

+25-25
Original file line numberDiff line numberDiff line change
@@ -82,55 +82,55 @@ def my_query(
8282
executor: edgedb.Executor,
8383
*,
8484
a: uuid.UUID,
85-
b: typing.Optional[uuid.UUID],
85+
b: typing.Optional[uuid.UUID] = None,
8686
c: str,
87-
d: typing.Optional[str],
87+
d: typing.Optional[str] = None,
8888
e: bytes,
89-
f: typing.Optional[bytes],
89+
f: typing.Optional[bytes] = None,
9090
g: int,
91-
h: typing.Optional[int],
91+
h: typing.Optional[int] = None,
9292
i: int,
93-
j: typing.Optional[int],
93+
j: typing.Optional[int] = None,
9494
k: int,
95-
l: typing.Optional[int],
95+
l: typing.Optional[int] = None,
9696
m: float,
97-
n: typing.Optional[float],
97+
n: typing.Optional[float] = None,
9898
o: float,
99-
p: typing.Optional[float],
99+
p: typing.Optional[float] = None,
100100
q: bool,
101-
r: typing.Optional[bool],
101+
r: typing.Optional[bool] = None,
102102
s: datetime.datetime,
103-
t: typing.Optional[datetime.datetime],
103+
t: typing.Optional[datetime.datetime] = None,
104104
u: datetime.datetime,
105-
v: typing.Optional[datetime.datetime],
105+
v: typing.Optional[datetime.datetime] = None,
106106
w: datetime.date,
107-
x: typing.Optional[datetime.date],
107+
x: typing.Optional[datetime.date] = None,
108108
y: datetime.time,
109-
z: typing.Optional[datetime.time],
109+
z: typing.Optional[datetime.time] = None,
110110
aa: datetime.timedelta,
111-
ab: typing.Optional[datetime.timedelta],
111+
ab: typing.Optional[datetime.timedelta] = None,
112112
ac: int,
113-
ad: typing.Optional[int],
113+
ad: typing.Optional[int] = None,
114114
ae: edgedb.RelativeDuration,
115-
af: typing.Optional[edgedb.RelativeDuration],
115+
af: typing.Optional[edgedb.RelativeDuration] = None,
116116
ag: edgedb.DateDuration,
117-
ah: typing.Optional[edgedb.DateDuration],
117+
ah: typing.Optional[edgedb.DateDuration] = None,
118118
ai: edgedb.ConfigMemory,
119-
aj: typing.Optional[edgedb.ConfigMemory],
119+
aj: typing.Optional[edgedb.ConfigMemory] = None,
120120
ak: edgedb.Range[int],
121-
al: typing.Optional[edgedb.Range[int]],
121+
al: typing.Optional[edgedb.Range[int]] = None,
122122
am: edgedb.Range[int],
123-
an: typing.Optional[edgedb.Range[int]],
123+
an: typing.Optional[edgedb.Range[int]] = None,
124124
ao: edgedb.Range[float],
125-
ap: typing.Optional[edgedb.Range[float]],
125+
ap: typing.Optional[edgedb.Range[float]] = None,
126126
aq: edgedb.Range[float],
127-
ar: typing.Optional[edgedb.Range[float]],
127+
ar: typing.Optional[edgedb.Range[float]] = None,
128128
as_: edgedb.Range[datetime.datetime],
129-
at: typing.Optional[edgedb.Range[datetime.datetime]],
129+
at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
130130
au: edgedb.Range[datetime.datetime],
131-
av: typing.Optional[edgedb.Range[datetime.datetime]],
131+
av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
132132
aw: edgedb.Range[datetime.date],
133-
ax: typing.Optional[edgedb.Range[datetime.date]],
133+
ax: typing.Optional[edgedb.Range[datetime.date]] = None,
134134
) -> MyQueryResult:
135135
return executor.query_single(
136136
"""\

tests/test_codegen.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ async def test_codegen(self):
4343
for project in container.iterdir():
4444
if project.name == "linked":
4545
continue
46-
cwd = td_path / project.name
47-
shutil.copytree(project, cwd)
48-
await self._test_codegen(env, cwd)
46+
with self.subTest(msg=project.name):
47+
cwd = td_path / project.name
48+
shutil.copytree(project, cwd)
49+
try:
50+
await self._test_codegen(env, cwd)
51+
except subprocess.CalledProcessError as e:
52+
self.fail("Codegen failed: " + e.stdout.decode())
4953

5054
async def _test_codegen(self, env, cwd: pathlib.Path):
5155
async def run(*args, extra_env=None):
@@ -67,6 +71,11 @@ async def run(*args, extra_env=None):
6771
p.terminate()
6872
await p.wait()
6973
raise
74+
else:
75+
if p.returncode:
76+
raise subprocess.CalledProcessError(
77+
p.returncode, args, output=await p.stdout.read(),
78+
)
7079

7180
cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "edgedb-py")
7281
await run(
@@ -90,7 +99,7 @@ async def run(*args, extra_env=None):
9099

91100
for f in cwd.rglob("*.py"):
92101
a = f.with_suffix(".py.assert")
93-
self.assertEqual(f.read_text(), a.read_text())
102+
self.assertEqual(f.read_text(), a.read_text(), msg=f.name)
94103
for a in cwd.rglob("*.py.assert"):
95104
f = a.with_suffix("")
96105
self.assertTrue(f.exists(), f"{f} doesn't exist")

0 commit comments

Comments
 (0)