Skip to content

Commit

Permalink
feat: tools (#1180)
Browse files — browse the repository at this point in the history
* feat: tools

* feat: claude use tools

* feat: openai use tools
  • Loading branch information
yetone authored Feb 5, 2025
1 parent 1726d32 commit 1437f31
Show file tree
Hide file tree
Showing 17 changed files with 1,321 additions and 74 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/lua.yaml
Original file line number · Diff line number · Diff line change
Expand Up @@ -38,6 +38,9 @@ jobs:
mkdir -p _neovim
curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim"
}
sudo apt-get update
sudo apt-get install -y ripgrep
sudo apt-get install -y silversearcher-ag
- name: Run tests
run: |
Expand Down
2 changes: 2 additions & 0 deletions crates/avante-templates/src/lib.rs
Original file line number · Diff line number · Diff line change
Expand Up @@ -31,6 +31,7 @@ struct TemplateContext {
selected_code: Option<String>,
project_context: Option<String>,
diagnostics: Option<String>,
system_info: Option<String>,
}

// Given the file name registered after add, the context table in Lua, resulted in a formatted
Expand All @@ -54,6 +55,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
selected_code => context.selected_code,
project_context => context.project_context,
diagnostics => context.diagnostics,
system_info => context.system_info,
})
.map_err(LuaError::external)
.unwrap())
Expand Down
8 changes: 8 additions & 0 deletions lua/avante/config.lua
Original file line number · Diff line number · Diff line change
Expand Up @@ -20,6 +20,14 @@ M._defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
web_search_engine = {
provider = "tavily",
api_key_name = "TAVILY_API_KEY",
provider_opts = {
time_range = "d",
include_answer = "basic",
},
},
---@type AvanteSupportedProvider
openai = {
endpoint = "https://api.openai.com/v1",
Expand Down
67 changes: 51 additions & 16 deletions lua/avante/llm.lua
Original file line number · Diff line number · Diff line change
Expand Up @@ -8,6 +8,7 @@ local Utils = require("avante.utils")
local Config = require("avante.config")
local Path = require("avante.path")
local P = require("avante.providers")
local LLMTools = require("avante.llm_tools")

---@class avante.LLM
local M = {}
Expand Down Expand Up @@ -45,6 +46,8 @@ M.generate_prompts = function(opts)
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))

local system_info = Utils.get_system_info()

local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
Expand All @@ -53,6 +56,7 @@ M.generate_prompts = function(opts)
selected_code = opts.selected_code,
project_context = opts.project_context,
diagnostics = opts.diagnostics,
system_info = system_info,
}

local system_prompt = Path.prompts.render_mode(mode, template_opts)
Expand Down Expand Up @@ -111,6 +115,10 @@ M.generate_prompts = function(opts)
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
tools = opts.tools,
tool_use = opts.tool_use,
tool_result = opts.tool_result,
response_content = opts.response_content,
}
end

Expand All @@ -135,7 +143,28 @@ M._stream = function(opts)
local current_event_state = nil

---@type AvanteHandlerOptions
local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete }
local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
if stop_opts.reason == "tool_use" and stop_opts.tool_use then
local result, error = LLMTools.process_tool_use(stop_opts.tool_use)
local tool_result = {
tool_use_id = stop_opts.tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
local new_opts = vim.tbl_deep_extend(
"force",
opts,
{ tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content }
)
return M._stream(new_opts)
end
return opts.on_stop(stop_opts)
end,
}

---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, code_opts)

Expand Down Expand Up @@ -180,7 +209,7 @@ M._stream = function(opts)
stream = function(err, data, _)
if err then
completed = true
opts.on_complete(err)
handler_opts.on_stop({ reason = "error", error = err })
return
end
if not data then return end
Expand Down Expand Up @@ -224,7 +253,7 @@ M._stream = function(opts)
active_job = nil
completed = true
cleanup()
opts.on_complete(err)
handler_opts.on_stop({ reason = "error", error = err })
end,
callback = function(result)
active_job = nil
Expand All @@ -238,9 +267,10 @@ M._stream = function(opts)
vim.schedule(function()
if not completed then
completed = true
opts.on_complete(
"API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body)
)
handler_opts.on_stop({
reason = "error",
error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
})
end
end)
end
Expand Down Expand Up @@ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
on_chunk = function(chunk)
if chunk then response = response .. chunk end
end,
on_complete = function(err)
if err then
Utils.error(string.format("Stream %d failed: %s", index, err))
on_stop = function(stop_opts)
if stop_opts.error then
Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
return
end
Utils.debug(string.format("Response %d completed", index))
Expand Down Expand Up @@ -381,10 +411,15 @@ end
---@field instructions string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
---
---@class StreamOptions: GeneratePromptsOptions
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback

---@param opts StreamOptions
M.stream = function(opts)
Expand All @@ -396,12 +431,12 @@ M.stream = function(opts)
return original_on_chunk(chunk)
end)
end
if opts.on_complete ~= nil then
local original_on_complete = opts.on_complete
opts.on_complete = vim.schedule_wrap(function(err)
if opts.on_stop ~= nil then
local original_on_stop = opts.on_stop
opts.on_stop = vim.schedule_wrap(function(stop_opts)
if is_completed then return end
is_completed = true
return original_on_complete(err)
if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
return original_on_stop(stop_opts)
end)
end
if Config.dual_boost.enabled then
Expand Down
Loading

0 comments on commit 1437f31

Please sign in to comment.