Skip to content

Commit 4a059d7

Browse files
committed
Log --async responses to DB, closes #641
Refs #507
1 parent a6d62b7 commit 4a059d7

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

llm/cli.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import json
77
from llm import (
88
Attachment,
9+
AsyncResponse,
910
Collection,
1011
Conversation,
1112
Response,
@@ -376,6 +377,7 @@ def read_prompt():
376377
validated_options["stream"] = False
377378

378379
prompt = read_prompt()
380+
response = None
379381

380382
prompt_method = model.prompt
381383
if conversation:
@@ -386,12 +388,13 @@ def read_prompt():
386388

387389
async def inner():
388390
if should_stream:
389-
async for chunk in prompt_method(
391+
response = prompt_method(
390392
prompt,
391393
attachments=resolved_attachments,
392394
system=system,
393395
**validated_options,
394-
):
396+
)
397+
async for chunk in response:
395398
print(chunk, end="")
396399
sys.stdout.flush()
397400
print("")
@@ -403,8 +406,9 @@ async def inner():
403406
**validated_options,
404407
)
405408
print(await response.text())
409+
return response
406410

407-
asyncio.run(inner())
411+
response = asyncio.run(inner())
408412
else:
409413
response = prompt_method(
410414
prompt,
@@ -423,11 +427,13 @@ async def inner():
423427
raise click.ClickException(str(ex))
424428

425429
# Log to the database
426-
if (logs_on() or log) and not no_log and not async_:
430+
if (logs_on() or log) and not no_log:
427431
log_path = logs_db_path()
428432
(log_path.parent).mkdir(parents=True, exist_ok=True)
429433
db = sqlite_utils.Database(log_path)
430434
migrate(db)
435+
if isinstance(response, AsyncResponse):
436+
response = asyncio.run(response.to_sync_response())
431437
response.log_to_db(db)
432438

433439

llm/models.py

+15
Original file line number | Diff line number | Diff line change
@@ -426,6 +426,21 @@ async def datetime_utc(self) -> str:
426426
def __await__(self):
427427
return self._force().__await__()
428428

429+
async def to_sync_response(self) -> Response:
430+
await self._force()
431+
response = Response(
432+
self.prompt,
433+
self.model,
434+
self.stream,
435+
conversation=self.conversation,
436+
)
437+
response._chunks = self._chunks
438+
response._done = True
439+
response._end = self._end
440+
response._start = self._start
441+
response._start_utcnow = self._start_utcnow
442+
return response
443+
429444
@classmethod
430445
def fake(
431446
cls,

tests/test_cli_openai_models.py

+51
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from click.testing import CliRunner
22
from llm.cli import cli
33
import pytest
4+
import sqlite_utils
45

56

67
@pytest.fixture
@@ -143,3 +144,53 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):
143144
assert (
144145
f"This model does not support attachments of type '{long}'" in result.output
145146
)
147+
148+
149+
@pytest.mark.parametrize("async_", (False, True))
150+
def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
151+
user_path = tmpdir / "user_dir"
152+
log_db = user_path / "logs.db"
153+
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
154+
assert not log_db.exists()
155+
httpx_mock.add_response(
156+
method="POST",
157+
# chat completion request
158+
url="https://api.openai.com/v1/chat/completions",
159+
json={
160+
"id": "chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh",
161+
"object": "chat.completion",
162+
"created": 1730871958,
163+
"model": "gpt-4o-mini",
164+
"choices": [
165+
{
166+
"index": 0,
167+
"message": {
168+
"role": "assistant",
169+
"content": "Ho ho ho",
170+
"refusal": None,
171+
},
172+
"finish_reason": "stop",
173+
}
174+
],
175+
"usage": {
176+
"prompt_tokens": 10,
177+
"completion_tokens": 2,
178+
"total_tokens": 12,
179+
},
180+
"system_fingerprint": "fp_49254d0e9b",
181+
},
182+
headers={"Content-Type": "application/json"},
183+
)
184+
runner = CliRunner()
185+
args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"]
186+
if async_:
187+
args.append("--async")
188+
result = runner.invoke(cli, args, catch_exceptions=False)
189+
assert result.exit_code == 0
190+
assert result.output == "Ho ho ho\n"
191+
# Confirm it was correctly logged
192+
assert log_db.exists()
193+
db = sqlite_utils.Database(str(log_db))
194+
assert db["responses"].count == 1
195+
row = next(db["responses"].rows)
196+
assert row["response"] == "Ho ho ho"

0 commit comments

Comments
 (0)