
Commit 80956da

Store input_tokens, output_tokens, token_details on Response, refs #610
1 parent 4a059d7 commit 80956da

8 files changed: +97 -3 lines changed

llm/cli.py (+3)

@@ -754,6 +754,9 @@ def logs_turn_off():
         responses.conversation_id,
         responses.duration_ms,
         responses.datetime_utc,
+        responses.input_tokens,
+        responses.output_tokens,
+        responses.token_details,
         conversations.name as conversation_name,
         conversations.model as conversation_model"""
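
The three new columns are added to the SELECT that backs the logs command, so stored token counts can also be queried straight from the logs database. A minimal sketch using the standard library, assuming a logs.db produced by llm (the path here is illustrative):

    import sqlite3

    db = sqlite3.connect("logs.db")  # hypothetical path to the logs database
    rows = db.execute(
        "select model, input_tokens, output_tokens, token_details"
        " from responses order by id desc limit 3"
    ).fetchall()
    for row in rows:
        print(row)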

llm/default_plugins/openai_models.py (+26 -1)

@@ -1,6 +1,11 @@
 from llm import AsyncModel, EmbeddingModel, Model, hookimpl
 import llm
-from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
+from llm.utils import (
+    dicts_to_table_string,
+    remove_dict_none_values,
+    logging_client,
+    simplify_usage_dict,
+)
 import click
 import datetime
 import httpx
@@ -391,6 +396,16 @@ def build_messages(self, prompt, conversation):
             messages.append({"role": "user", "content": attachment_message})
         return messages
 
+    def set_usage(self, response, usage):
+        if not usage:
+            return
+        input_tokens = usage.pop("prompt_tokens")
+        output_tokens = usage.pop("completion_tokens")
+        usage.pop("total_tokens")
+        response.set_usage(
+            input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
+        )
+
     def get_client(self, async_=False):
         kwargs = {}
         if self.api_base:
@@ -445,6 +460,7 @@ def execute(self, prompt, stream, response, conversation=None):
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
+        usage = None
         if stream:
             completion = client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -455,6 +471,8 @@ def execute(self, prompt, stream, response, conversation=None):
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 try:
                     content = chunk.choices[0].delta.content
                 except IndexError:
@@ -469,8 +487,10 @@ def execute(self, prompt, stream, response, conversation=None):
                 stream=False,
                 **kwargs,
             )
+            usage = completion.usage.model_dump()
             response.response_json = remove_dict_none_values(completion.model_dump())
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
 
 
@@ -493,6 +513,7 @@ async def execute(
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client(async_=True)
+        usage = None
         if stream:
             completion = await client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -502,6 +523,8 @@ async def execute(
             )
             chunks = []
             async for chunk in completion:
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 chunks.append(chunk)
                 try:
                     content = chunk.choices[0].delta.content
@@ -518,7 +541,9 @@ async def execute(
                 **kwargs,
             )
             response.response_json = remove_dict_none_values(completion.model_dump())
+            usage = completion.usage.model_dump()
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
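
set_usage normalizes the usage payload returned by the OpenAI client: the three top-level counts are popped off and whatever remains is passed through simplify_usage_dict. A sketch with sample values (not taken from the commit):

    usage = {
        "prompt_tokens": 12,
        "completion_tokens": 7,
        "total_tokens": 19,
        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0},
    }
    # set_usage pops prompt_tokens, completion_tokens and total_tokens, then
    # simplify_usage_dict drops the zero-valued leftovers, so this ends up as:
    # response.set_usage(input=12, output=7, details={})

When streaming, usage arrives (if at all) on one of the streamed chunks rather than on a completion object, which is why both streaming branches check chunk.usage inside the loop instead of reading completion.usage.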

llm/migrations.py (+7)

@@ -227,3 +227,10 @@ def m012_attachments_tables(db):
         ),
         pk=("response_id", "attachment_id"),
     )
+
+
+@migration
+def m013_usage(db):
+    db["responses"].add_column("input_tokens", int)
+    db["responses"].add_column("output_tokens", int)
+    db["responses"].add_column("token_details", str)

llm/models.py (+23 -1)

@@ -208,6 +208,20 @@ def __init__(
         self._start: Optional[float] = None
         self._end: Optional[float] = None
         self._start_utcnow: Optional[datetime.datetime] = None
+        self.input_tokens: Optional[int] = None
+        self.output_tokens: Optional[int] = None
+        self.token_details: Optional[dict] = None
+
+    def set_usage(
+        self,
+        *,
+        input: Optional[int] = None,
+        output: Optional[int] = None,
+        details: Optional[dict] = None,
+    ):
+        self.input_tokens = input
+        self.output_tokens = output
+        self.token_details = details
 
     @classmethod
     def from_row(cls, db, row):
@@ -272,11 +286,16 @@ def log_to_db(self, db):
                 for key, value in dict(self.prompt.options).items()
                 if value is not None
             },
-            "response": self.text(),
+            "response": self.text_or_raise(),
             "response_json": self.json(),
            "conversation_id": conversation.id,
             "duration_ms": self.duration_ms(),
             "datetime_utc": self.datetime_utc(),
+            "input_tokens": self.input_tokens,
+            "output_tokens": self.output_tokens,
+            "token_details": (
+                json.dumps(self.token_details) if self.token_details else None
+            ),
         }
         db["responses"].insert(response)
         # Persist any attachments - loop through with index
@@ -439,6 +458,9 @@ async def to_sync_response(self) -> Response:
         response._end = self._end
         response._start = self._start
         response._start_utcnow = self._start_utcnow
+        response.input_tokens = self.input_tokens
+        response.output_tokens = self.output_tokens
+        response.token_details = self.token_details
         return response
 
     @classmethod
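
set_usage is the model-agnostic entry point plugins use to attach token counts to a Response; all three arguments are keyword-only and optional. A sketch of a plugin's execute method calling it (the counts and details are placeholders):

    def execute(self, prompt, stream, response, conversation=None):
        yield "output text"  # ...however the plugin produces its output...
        # record whatever usage the provider reported
        response.set_usage(input=12, output=7, details={"cached_tokens": 4})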

llm/utils.py (+15)

@@ -127,3 +127,18 @@ def logging_client() -> httpx.Client:
         transport=_LogTransport(httpx.HTTPTransport()),
         event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
     )
+
+
+def simplify_usage_dict(d):
+    # Recursively remove keys with value 0 and empty dictionaries
+    def remove_empty_and_zero(obj):
+        if isinstance(obj, dict):
+            cleaned = {
+                k: remove_empty_and_zero(v)
+                for k, v in obj.items()
+                if v != 0 and v != {}
+            }
+            return {k: v for k, v in cleaned.items() if v is not None and v != {}}
+        return obj
+
+    return remove_empty_and_zero(d) or {}
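
For example, a details dict in which most counters are zero collapses down to just the meaningful entries (values invented for illustration):

    simplify_usage_dict({
        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0},
        "prompt_tokens_details": {"cached_tokens": 32, "audio_tokens": 0},
    })
    # -> {"prompt_tokens_details": {"cached_tokens": 32}}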

tests/conftest.py (+5 -1)

@@ -66,13 +66,17 @@ def enqueue(self, messages):
 
     def execute(self, prompt, stream, response, conversation):
         self.history.append((prompt, stream, response, conversation))
+        gathered = []
         while True:
             try:
                 messages = self._queue.pop(0)
-                yield from messages
+                for message in messages:
+                    gathered.append(message)
+                    yield message
                 break
             except IndexError:
                 break
+        response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
 
 
 class AsyncMockModel(llm.AsyncModel):
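
The mock now reports deterministic usage: input is the whitespace-delimited word count of the prompt, output is the number of chunks it yielded, which is where the expected values in the tests below come from. A sketch, assuming the mock_model fixture from this file:

    mock_model.enqueue(["one world"])
    response = mock_model.prompt("hi")
    # once the response has been consumed:
    # response.input_tokens == 1   ("hi" is one word)
    # response.output_tokens == 1  (one enqueued message was yielded)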

tests/test_chat.py (+15)

@@ -62,6 +62,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         },
         {
             "id": ANY,
@@ -75,6 +78,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 2,
+            "output_tokens": 2,
+            "token_details": None,
         },
     ]
     # Now continue that conversation
@@ -116,6 +122,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -153,6 +162,9 @@ def test_chat_system(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -181,6 +193,9 @@ def test_chat_options(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]

tests/test_migrate.py (+3)

@@ -17,6 +17,9 @@
     "conversation_id": str,
     "duration_ms": int,
     "datetime_utc": str,
+    "input_tokens": int,
+    "output_tokens": int,
+    "token_details": str,
 }