first commit
This commit is contained in:
171
MeetSpot/app/agent/base.py
Normal file
171
MeetSpot/app/agent/base.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Agent 基类 - 参考 OpenManus BaseAgent 设计"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Memory, Message, ROLE_TYPE
|
||||
|
||||
|
||||
class BaseAgent(BaseModel, ABC):
|
||||
"""Agent 基类
|
||||
|
||||
提供基础的状态管理、记忆管理和执行循环。
|
||||
子类需要实现 step() 方法来定义具体行为。
|
||||
"""
|
||||
|
||||
# 核心属性
|
||||
name: str = Field(default="BaseAgent", description="Agent 名称")
|
||||
description: Optional[str] = Field(default=None, description="Agent 描述")
|
||||
|
||||
# 提示词
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
||||
next_step_prompt: Optional[str] = Field(default=None, description="下一步提示词")
|
||||
|
||||
# 依赖
|
||||
llm: Optional[LLM] = Field(default=None, description="LLM 实例")
|
||||
memory: Memory = Field(default_factory=Memory, description="Agent 记忆")
|
||||
state: AgentState = Field(default=AgentState.IDLE, description="当前状态")
|
||||
|
||||
# 执行控制
|
||||
max_steps: int = Field(default=10, description="最大执行步数")
|
||||
current_step: int = Field(default=0, description="当前步数")
|
||||
|
||||
# 重复检测阈值
|
||||
duplicate_threshold: int = 2
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize_agent(self) -> "BaseAgent":
|
||||
"""初始化 Agent"""
|
||||
if self.llm is None:
|
||||
try:
|
||||
self.llm = LLM()
|
||||
except Exception as e:
|
||||
logger.warning(f"无法初始化 LLM: {e}")
|
||||
if not isinstance(self.memory, Memory):
|
||||
self.memory = Memory()
|
||||
return self
|
||||
|
||||
@asynccontextmanager
|
||||
async def state_context(self, new_state: AgentState):
|
||||
"""状态上下文管理器,用于安全的状态转换"""
|
||||
if not isinstance(new_state, AgentState):
|
||||
raise ValueError(f"无效状态: {new_state}")
|
||||
|
||||
previous_state = self.state
|
||||
self.state = new_state
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
self.state = AgentState.ERROR
|
||||
raise e
|
||||
finally:
|
||||
self.state = previous_state
|
||||
|
||||
def update_memory(
|
||||
self,
|
||||
role: ROLE_TYPE,
|
||||
content: str,
|
||||
base64_image: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""添加消息到记忆"""
|
||||
message_map = {
|
||||
"user": Message.user_message,
|
||||
"system": Message.system_message,
|
||||
"assistant": Message.assistant_message,
|
||||
"tool": lambda content, **kw: Message.tool_message(content, **kw),
|
||||
}
|
||||
|
||||
if role not in message_map:
|
||||
raise ValueError(f"不支持的消息角色: {role}")
|
||||
|
||||
if role == "tool":
|
||||
self.memory.add_message(message_map[role](content, **kwargs))
|
||||
else:
|
||||
self.memory.add_message(message_map[role](content, base64_image=base64_image))
|
||||
|
||||
async def run(self, request: Optional[str] = None) -> str:
|
||||
"""执行 Agent 主循环
|
||||
|
||||
Args:
|
||||
request: 可选的初始用户请求
|
||||
|
||||
Returns:
|
||||
执行结果摘要
|
||||
"""
|
||||
if self.state != AgentState.IDLE:
|
||||
raise RuntimeError(f"无法从状态 {self.state} 启动 Agent")
|
||||
|
||||
if request:
|
||||
self.update_memory("user", request)
|
||||
|
||||
results: List[str] = []
|
||||
async with self.state_context(AgentState.RUNNING):
|
||||
while (
|
||||
self.current_step < self.max_steps
|
||||
and self.state != AgentState.FINISHED
|
||||
):
|
||||
self.current_step += 1
|
||||
logger.info(f"执行步骤 {self.current_step}/{self.max_steps}")
|
||||
step_result = await self.step()
|
||||
|
||||
# 检测卡住状态
|
||||
if self.is_stuck():
|
||||
self.handle_stuck_state()
|
||||
|
||||
results.append(f"Step {self.current_step}: {step_result}")
|
||||
|
||||
if self.current_step >= self.max_steps:
|
||||
self.current_step = 0
|
||||
self.state = AgentState.IDLE
|
||||
results.append(f"已终止: 达到最大步数 ({self.max_steps})")
|
||||
|
||||
return "\n".join(results) if results else "未执行任何步骤"
|
||||
|
||||
@abstractmethod
|
||||
async def step(self) -> str:
|
||||
"""执行单步操作 - 子类必须实现"""
|
||||
pass
|
||||
|
||||
def handle_stuck_state(self):
|
||||
"""处理卡住状态"""
|
||||
stuck_prompt = "检测到重复响应。请考虑新策略,避免重复已尝试过的无效路径。"
|
||||
self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt or ''}"
|
||||
logger.warning(f"Agent 检测到卡住状态,已添加提示")
|
||||
|
||||
def is_stuck(self) -> bool:
|
||||
"""检测是否陷入循环"""
|
||||
if len(self.memory.messages) < 2:
|
||||
return False
|
||||
|
||||
last_message = self.memory.messages[-1]
|
||||
if not last_message.content:
|
||||
return False
|
||||
|
||||
# 统计相同内容出现次数
|
||||
duplicate_count = sum(
|
||||
1
|
||||
for msg in reversed(self.memory.messages[:-1])
|
||||
if msg.role == "assistant" and msg.content == last_message.content
|
||||
)
|
||||
|
||||
return duplicate_count >= self.duplicate_threshold
|
||||
|
||||
@property
|
||||
def messages(self) -> List[Message]:
|
||||
"""获取记忆中的消息列表"""
|
||||
return self.memory.messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value: List[Message]):
|
||||
"""设置记忆中的消息列表"""
|
||||
self.memory.messages = value
|
||||
Reference in New Issue
Block a user