1
1
from llm import AsyncModel , EmbeddingModel , Model , hookimpl
2
2
import llm
3
- from llm .utils import dicts_to_table_string , remove_dict_none_values , logging_client
3
+ from llm .utils import (
4
+ dicts_to_table_string ,
5
+ remove_dict_none_values ,
6
+ logging_client ,
7
+ simplify_usage_dict ,
8
+ )
4
9
import click
5
10
import datetime
6
11
import httpx
@@ -391,6 +396,16 @@ def build_messages(self, prompt, conversation):
391
396
messages .append ({"role" : "user" , "content" : attachment_message })
392
397
return messages
393
398
399
def set_usage(self, response, usage):
    """Record token usage from an OpenAI-style usage dict on ``response``.

    ``usage`` is the dict the callers build with ``model_dump()`` on a
    chunk's or completion's ``usage`` object. A falsy value (``None`` or
    an empty dict) means nothing was reported, so nothing is recorded.

    ``prompt_tokens`` and ``completion_tokens`` are forwarded to
    ``response.set_usage`` as ``input`` and ``output``. ``total_tokens``
    is dropped, and whatever keys remain are passed through
    ``simplify_usage_dict`` and forwarded as ``details``.
    """
    if not usage:
        return
    # Work on a shallow copy so the caller's dict is left untouched.
    remaining = dict(usage)
    # Pop with a default: a payload that omits one of these counters
    # yields None for that field instead of raising KeyError.
    input_tokens = remaining.pop("prompt_tokens", None)
    output_tokens = remaining.pop("completion_tokens", None)
    remaining.pop("total_tokens", None)
    response.set_usage(
        input=input_tokens,
        output=output_tokens,
        details=simplify_usage_dict(remaining),
    )
408
+
394
409
def get_client (self , async_ = False ):
395
410
kwargs = {}
396
411
if self .api_base :
@@ -445,6 +460,7 @@ def execute(self, prompt, stream, response, conversation=None):
445
460
messages = self .build_messages (prompt , conversation )
446
461
kwargs = self .build_kwargs (prompt , stream )
447
462
client = self .get_client ()
463
+ usage = None
448
464
if stream :
449
465
completion = client .chat .completions .create (
450
466
model = self .model_name or self .model_id ,
@@ -455,6 +471,8 @@ def execute(self, prompt, stream, response, conversation=None):
455
471
chunks = []
456
472
for chunk in completion :
457
473
chunks .append (chunk )
474
+ if chunk .usage :
475
+ usage = chunk .usage .model_dump ()
458
476
try :
459
477
content = chunk .choices [0 ].delta .content
460
478
except IndexError :
@@ -469,8 +487,10 @@ def execute(self, prompt, stream, response, conversation=None):
469
487
stream = False ,
470
488
** kwargs ,
471
489
)
490
+ usage = completion .usage .model_dump ()
472
491
response .response_json = remove_dict_none_values (completion .model_dump ())
473
492
yield completion .choices [0 ].message .content
493
+ self .set_usage (response , usage )
474
494
response ._prompt_json = redact_data ({"messages" : messages })
475
495
476
496
@@ -493,6 +513,7 @@ async def execute(
493
513
messages = self .build_messages (prompt , conversation )
494
514
kwargs = self .build_kwargs (prompt , stream )
495
515
client = self .get_client (async_ = True )
516
+ usage = None
496
517
if stream :
497
518
completion = await client .chat .completions .create (
498
519
model = self .model_name or self .model_id ,
@@ -502,6 +523,8 @@ async def execute(
502
523
)
503
524
chunks = []
504
525
async for chunk in completion :
526
+ if chunk .usage :
527
+ usage = chunk .usage .model_dump ()
505
528
chunks .append (chunk )
506
529
try :
507
530
content = chunk .choices [0 ].delta .content
@@ -518,7 +541,9 @@ async def execute(
518
541
** kwargs ,
519
542
)
520
543
response .response_json = remove_dict_none_values (completion .model_dump ())
544
+ usage = completion .usage .model_dump ()
521
545
yield completion .choices [0 ].message .content
546
+ self .set_usage (response , usage )
522
547
response ._prompt_json = redact_data ({"messages" : messages })
523
548
524
549
0 commit comments