Skip to content

Commit

Permalink
fix: pass context to provider for stream data parsing (#1475)
Browse files Browse the repository at this point in the history
* fix: pass context to provider for stream data parsing

* fix: luatype

---------

Co-authored-by: yetone <yetoneful@gmail.com>
  • Loading branch information
brookhong and yetone authored Mar 3, 2025
1 parent c6d5527 commit 6bd966e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions lua/avante/llm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ end
function M._stream(opts)
local provider = opts.provider or Providers[Config.provider]

---@cast provider AvanteProviderFunctor

local prompt_opts = M.generate_prompts(opts)

---@type string
Expand Down Expand Up @@ -285,10 +287,10 @@ function M._stream(opts)
{ once = true }
)
end
provider.parse_stream_data(data, handler_opts)
provider.parse_stream_data(resp_ctx, data, handler_opts)
else
if provider.parse_stream_data ~= nil then
provider.parse_stream_data(data, handler_opts)
provider.parse_stream_data(resp_ctx, data, handler_opts)
else
parse_stream_data(data)
end
Expand Down
4 changes: 2 additions & 2 deletions lua/avante/providers/bedrock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function M.build_bedrock_payload(prompt_opts, body_opts)
return model_handler.build_bedrock_payload(prompt_opts, body_opts)
end

function M.parse_stream_data(data, opts)
function M.parse_stream_data(ctx, data, opts)
-- @NOTE: Decode and process Bedrock response
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
-- The `type` field in the decoded JSON determines how the response is handled.
Expand All @@ -37,7 +37,7 @@ function M.parse_stream_data(data, opts)
local jsn = vim.json.decode(bedrock_data_match)
local data_stream = vim.base64.decode(jsn.bytes)
local json = vim.json.decode(data_stream)
M.parse_response({}, data_stream, json.type, opts)
M.parse_response(ctx, data_stream, json.type, opts)
end
end

Expand Down
2 changes: 1 addition & 1 deletion lua/avante/providers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function M.parse_messages(opts)
return { messages = messages }
end

function M.parse_stream_data(data, opts)
function M.parse_stream_data(ctx, data, opts)
---@type CohereChatResponse
local json = vim.json.decode(data)
if json.type ~= nil then
Expand Down
2 changes: 1 addition & 1 deletion lua/avante/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_use_list? AvanteLLMToolUse[]
---@field retry_after? integer
---
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteStreamParser fun(ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
---@alias AvanteLLMChunkCallback fun(chunk: string): any
---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil
Expand Down

0 comments on commit 6bd966e

Please sign in to comment.