Skip to content

Commit 067c59e

Browse files
authored
Merge pull request #516 from matengm1/refactor/standardize-tool-choice-literals
Standardize literals for role and tool choice type definitions
2 parents c6cd296 + 6b64b98 commit 067c59e

File tree

6 files changed

+45
-26
lines changed

6 files changed

+45
-26
lines changed

app/agent/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from abc import ABC, abstractmethod
22
from contextlib import asynccontextmanager
3-
from typing import List, Literal, Optional
3+
from typing import List, Optional
44

55
from pydantic import BaseModel, Field, model_validator
66

77
from app.llm import LLM
88
from app.logger import logger
9-
from app.schema import AgentState, Memory, Message
9+
from app.schema import AgentState, Memory, Message, ROLE_TYPE
1010

1111

1212
class BaseAgent(BaseModel, ABC):
@@ -82,7 +82,7 @@ async def state_context(self, new_state: AgentState):
8282

8383
def update_memory(
8484
self,
85-
role: Literal["user", "system", "assistant", "tool"],
85+
role: ROLE_TYPE, # type: ignore
8686
content: str,
8787
**kwargs,
8888
) -> None:

app/agent/planning.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import time
2-
from typing import Dict, List, Literal, Optional
2+
from typing import Dict, List, Optional
33

44
from pydantic import Field, model_validator
55

66
from app.agent.toolcall import ToolCallAgent
77
from app.logger import logger
88
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
9-
from app.schema import Message, ToolCall
9+
from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice
1010
from app.tool import PlanningTool, Terminate, ToolCollection
1111

1212

@@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent):
2727
available_tools: ToolCollection = Field(
2828
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
2929
)
30-
tool_choices: Literal["none", "auto", "required"] = "auto"
30+
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
3131
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
3232

3333
tool_calls: List[ToolCall] = Field(default_factory=list)
@@ -212,7 +212,7 @@ async def create_initial_plan(self, request: str) -> None:
212212
messages=messages,
213213
system_msgs=[Message.system_message(self.system_prompt)],
214214
tools=self.available_tools.to_params(),
215-
tool_choice="required",
215+
tool_choice=ToolChoice.REQUIRED,
216216
)
217217
assistant_msg = Message.from_tool_calls(
218218
content=response.content, tool_calls=response.tool_calls

app/agent/toolcall.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
2+
23
from typing import Any, List, Literal, Optional, Union
34

45
from pydantic import Field
56

67
from app.agent.react import ReActAgent
78
from app.logger import logger
89
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
9-
from app.schema import AgentState, Message, ToolCall
10+
from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice
1011
from app.tool import CreateChatCompletion, Terminate, ToolCollection
1112

1213

@@ -25,7 +26,7 @@ class ToolCallAgent(ReActAgent):
2526
available_tools: ToolCollection = ToolCollection(
2627
CreateChatCompletion(), Terminate()
2728
)
28-
tool_choices: Literal["none", "auto", "required"] = "auto"
29+
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
2930
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
3031

3132
tool_calls: List[ToolCall] = Field(default_factory=list)
@@ -62,7 +63,7 @@ async def think(self) -> bool:
6263

6364
try:
6465
# Handle different tool_choices modes
65-
if self.tool_choices == "none":
66+
if self.tool_choices == ToolChoice.NONE:
6667
if response.tool_calls:
6768
logger.warning(
6869
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
@@ -82,11 +83,11 @@ async def think(self) -> bool:
8283
)
8384
self.memory.add_message(assistant_msg)
8485

85-
if self.tool_choices == "required" and not self.tool_calls:
86+
if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
8687
return True # Will be handled in act()
8788

8889
# For 'auto' mode, continue with content if no commands but content exists
89-
if self.tool_choices == "auto" and not self.tool_calls:
90+
if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
9091
return bool(response.content)
9192

9293
return bool(self.tool_calls)
@@ -102,7 +103,7 @@ async def think(self) -> bool:
102103
async def act(self) -> str:
103104
"""Execute tool calls and handle their results"""
104105
if not self.tool_calls:
105-
if self.tool_choices == "required":
106+
if self.tool_choices == ToolChoice.REQUIRED:
106107
raise ValueError(TOOL_CALL_REQUIRED)
107108

108109
# Return last message content if no tool calls

app/flow/planning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from app.flow.base import BaseFlow, PlanStepStatus
99
from app.llm import LLM
1010
from app.logger import logger
11-
from app.schema import AgentState, Message
11+
from app.schema import AgentState, Message, ToolChoice
1212
from app.tool import PlanningTool
1313

1414

@@ -124,7 +124,7 @@ async def _create_initial_plan(self, request: str) -> None:
124124
messages=[user_message],
125125
system_msgs=[system_message],
126126
tools=[self.planning_tool.to_param()],
127-
tool_choice="required",
127+
tool_choice=ToolChoice.REQUIRED,
128128
)
129129

130130
# Process tool calls if present

app/llm.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Literal, Optional, Union
1+
from typing import Dict, List, Optional, Union
22

33
from openai import (
44
APIError,
@@ -12,7 +12,7 @@
1212

1313
from app.config import LLMSettings, config
1414
from app.logger import logger # Assuming a logger is set up in your app
15-
from app.schema import Message
15+
from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
1616

1717

1818
class LLM:
@@ -88,7 +88,7 @@ def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
8888

8989
# Validate all messages have required fields
9090
for msg in formatted_messages:
91-
if msg["role"] not in ["system", "user", "assistant", "tool"]:
91+
if msg["role"] not in ROLE_VALUES:
9292
raise ValueError(f"Invalid role: {msg['role']}")
9393
if "content" not in msg and "tool_calls" not in msg:
9494
raise ValueError(
@@ -187,7 +187,7 @@ async def ask_tool(
187187
system_msgs: Optional[List[Union[dict, Message]]] = None,
188188
timeout: int = 300,
189189
tools: Optional[List[dict]] = None,
190-
tool_choice: Literal["none", "auto", "required"] = "auto",
190+
tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
191191
temperature: Optional[float] = None,
192192
**kwargs,
193193
):
@@ -213,7 +213,7 @@ async def ask_tool(
213213
"""
214214
try:
215215
# Validate tool_choice
216-
if tool_choice not in ["none", "auto", "required"]:
216+
if tool_choice not in TOOL_CHOICE_VALUES:
217217
raise ValueError(f"Invalid tool_choice: {tool_choice}")
218218

219219
# Format messages

app/schema.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,24 @@
33

44
from pydantic import BaseModel, Field
55

6+
class Role(str, Enum):
7+
"""Message role options"""
8+
SYSTEM = "system"
9+
USER = "user"
10+
ASSISTANT = "assistant"
11+
TOOL = "tool"
12+
13+
ROLE_VALUES = tuple(role.value for role in Role)
14+
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
15+
16+
class ToolChoice(str, Enum):
17+
"""Tool choice options"""
18+
NONE = "none"
19+
AUTO = "auto"
20+
REQUIRED = "required"
21+
22+
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
23+
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
624

725
class AgentState(str, Enum):
826
"""Agent execution states"""
@@ -29,7 +47,7 @@ class ToolCall(BaseModel):
2947
class Message(BaseModel):
3048
"""Represents a chat message in the conversation"""
3149

32-
role: Literal["system", "user", "assistant", "tool"] = Field(...)
50+
role: ROLE_TYPE = Field(...) # type: ignore
3351
content: Optional[str] = Field(default=None)
3452
tool_calls: Optional[List[ToolCall]] = Field(default=None)
3553
name: Optional[str] = Field(default=None)
@@ -71,22 +89,22 @@ def to_dict(self) -> dict:
7189
@classmethod
7290
def user_message(cls, content: str) -> "Message":
7391
"""Create a user message"""
74-
return cls(role="user", content=content)
92+
return cls(role=Role.USER, content=content)
7593

7694
@classmethod
7795
def system_message(cls, content: str) -> "Message":
7896
"""Create a system message"""
79-
return cls(role="system", content=content)
97+
return cls(role=Role.SYSTEM, content=content)
8098

8199
@classmethod
82100
def assistant_message(cls, content: Optional[str] = None) -> "Message":
83101
"""Create an assistant message"""
84-
return cls(role="assistant", content=content)
102+
return cls(role=Role.ASSISTANT, content=content)
85103

86104
@classmethod
87105
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
88106
"""Create a tool message"""
89-
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id)
107+
return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id)
90108

91109
@classmethod
92110
def from_tool_calls(
@@ -103,7 +121,7 @@ def from_tool_calls(
103121
for call in tool_calls
104122
]
105123
return cls(
106-
role="assistant", content=content, tool_calls=formatted_calls, **kwargs
124+
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
107125
)
108126

109127

0 commit comments

Comments
 (0)