Skip to content

Commit 517ae3d

Browse files
committed
Initial support for schemas, refs #49
Refs simonw/llm#776
1 parent 1db6c5f commit 517ae3d

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

llm_gemini.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import httpx
23
import ijson
34
import llm
@@ -79,10 +80,26 @@ def resolve_type(attachment):
7980
return mime_type
8081

8182

83+
def cleanup_schema(schema, keys_to_remove=("$schema", "additionalProperties")):
    """Strip JSON-schema keys that Gemini's response_schema does not accept.

    Gemini supports only a subset of JSON schema, so unsupported keys
    (by default ``"$schema"`` and ``"additionalProperties"``) are removed
    recursively from every nested dict, including dicts inside lists.

    Mutates ``schema`` in place and returns the same object for
    convenience (callers typically pass a ``copy.deepcopy``).

    :param schema: a JSON-schema fragment — dict, list, or scalar
    :param keys_to_remove: iterable of keys to drop from every dict;
        defaults to the keys Gemini rejects
    :return: the cleaned ``schema`` (same object, modified in place)
    """
    if isinstance(schema, dict):
        for key in keys_to_remove:
            # pop with a default so missing keys are not an error
            schema.pop(key, None)
        for value in schema.values():
            cleanup_schema(value, keys_to_remove)
    elif isinstance(schema, list):
        for value in schema:
            cleanup_schema(value, keys_to_remove)
    return schema
96+
97+
8298
class _SharedGemini:
8399
needs_key = "gemini"
84100
key_env_var = "LLM_GEMINI_KEY"
85101
can_stream = True
102+
supports_schema = True
86103

87104
attachment_types = (
88105
# Text
@@ -226,6 +243,12 @@ def build_request_body(self, prompt, conversation):
226243
if prompt.system:
227244
body["systemInstruction"] = {"parts": [{"text": prompt.system}]}
228245

246+
if prompt.schema:
247+
body["generationConfig"] = {
248+
"response_mime_type": "application/json",
249+
"response_schema": cleanup_schema(copy.deepcopy(prompt.schema)),
250+
}
251+
229252
config_map = {
230253
"temperature": "temperature",
231254
"max_output_tokens": "maxOutputTokens",

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ classifiers = [
99
"License :: OSI Approved :: Apache Software License"
1010
]
1111
dependencies = [
12-
"llm>=0.22",
12+
"llm>=0.23a0",
1313
"httpx",
1414
"ijson"
1515
]

0 commit comments

Comments
 (0)