1
1
from abc import ABC , abstractmethod
2
- from contextlib import asynccontextmanager
3
2
from typing import List , Literal , Optional
4
3
5
4
from pydantic import BaseModel , Field , model_validator
6
5
7
6
from app .llm import LLM
8
7
from app .logger import logger
9
- from app .schema import AgentState , Memory , Message
8
+ from app .schema import Message
9
+ from .components .state_manager import StateManager , AgentState
10
+ from .components .memory_manager import MemoryManager
10
11
11
12
12
13
class BaseAgent (BaseModel , ABC ):
@@ -30,16 +31,14 @@ class BaseAgent(BaseModel, ABC):
30
31
31
32
# Dependencies
32
33
llm : LLM = Field (default_factory = LLM , description = "Language model instance" )
33
- memory : Memory = Field (default_factory = Memory , description = "Agent's memory store" )
34
- state : AgentState = Field (
35
- default = AgentState .IDLE , description = "Current agent state"
36
- )
34
+ memory : MemoryManager = Field (default_factory = MemoryManager , description = "Agent's memory store" )
35
+ state_manager : StateManager = Field (default_factory = StateManager , description = "Agent's state manager" )
37
36
38
37
# Execution control
39
38
max_steps : int = Field (default = 10 , description = "Maximum steps before termination" )
40
39
current_step : int = Field (default = 0 , description = "Current step in execution" )
41
40
42
- duplicate_threshold : int = 2
41
+
43
42
44
43
class Config :
45
44
arbitrary_types_allowed = True
@@ -50,35 +49,21 @@ def initialize_agent(self) -> "BaseAgent":
50
49
"""Initialize agent with default settings if not provided."""
51
50
if self .llm is None or not isinstance (self .llm , LLM ):
52
51
self .llm = LLM (config_name = self .name .lower ())
53
- if not isinstance (self .memory , Memory ):
54
- self .memory = Memory ()
52
+ if not isinstance (self .memory , MemoryManager ):
53
+ self .memory = MemoryManager ()
54
+ if not isinstance (self .state_manager , StateManager ):
55
+ self .state_manager = StateManager ()
55
56
return self
56
57
57
- @asynccontextmanager
58
- async def state_context (self , new_state : AgentState ):
59
- """Context manager for safe agent state transitions.
60
-
61
- Args:
62
- new_state: The state to transition to during the context.
63
-
64
- Yields:
65
- None: Allows execution within the new state.
58
+ @property
59
+ def state (self ) -> AgentState :
60
+ """Get the current agent state."""
61
+ return self .state_manager .state
66
62
67
- Raises:
68
- ValueError: If the new_state is invalid.
69
- """
70
- if not isinstance (new_state , AgentState ):
71
- raise ValueError (f"Invalid state: { new_state } " )
72
-
73
- previous_state = self .state
74
- self .state = new_state
75
- try :
76
- yield
77
- except Exception as e :
78
- self .state = AgentState .ERROR # Transition to ERROR on failure
79
- raise e
80
- finally :
81
- self .state = previous_state # Revert to previous state
63
+ @state .setter
64
+ def state (self , new_state : AgentState ):
65
+ """Set the agent state."""
66
+ self .state_manager .state = new_state
82
67
83
68
def update_memory (
84
69
self ,
@@ -129,7 +114,7 @@ async def run(self, request: Optional[str] = None) -> str:
129
114
self .update_memory ("user" , request )
130
115
131
116
results : List [str ] = []
132
- async with self .state_context (AgentState .RUNNING ):
117
+ async with self .state_manager . state_context (AgentState .RUNNING ):
133
118
while (
134
119
self .current_step < self .max_steps and self .state != AgentState .FINISHED
135
120
):
@@ -162,24 +147,6 @@ def handle_stuck_state(self):
162
147
self .next_step_prompt = f"{ stuck_prompt } \n { self .next_step_prompt } "
163
148
logger .warning (f"Agent detected stuck state. Added prompt: { stuck_prompt } " )
164
149
165
- def is_stuck (self ) -> bool :
166
- """Check if the agent is stuck in a loop by detecting duplicate content"""
167
- if len (self .memory .messages ) < 2 :
168
- return False
169
-
170
- last_message = self .memory .messages [- 1 ]
171
- if not last_message .content :
172
- return False
173
-
174
- # Count identical content occurrences
175
- duplicate_count = sum (
176
- 1
177
- for msg in reversed (self .memory .messages [:- 1 ])
178
- if msg .role == "assistant" and msg .content == last_message .content
179
- )
180
-
181
- return duplicate_count >= self .duplicate_threshold
182
-
183
150
@property
184
151
def messages (self ) -> List [Message ]:
185
152
"""Retrieve a list of messages from the agent's memory."""
@@ -189,3 +156,7 @@ def messages(self) -> List[Message]:
189
156
def messages (self , value : List [Message ]):
190
157
"""Set the list of messages in the agent's memory."""
191
158
self .memory .messages = value
159
+
160
+ def is_stuck (self ) -> bool :
161
+ """Check if the agent is stuck in a loop by detecting duplicate content."""
162
+ return self .memory .is_stuck ()
0 commit comments