first commit
This commit is contained in:
20
MeetSpot/app/agent/__init__.py
Normal file
20
MeetSpot/app/agent/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""MeetSpot Agent Module - 基于 OpenManus 架构的智能推荐 Agent"""
|
||||
|
||||
from app.agent.base import BaseAgent
|
||||
from app.agent.meetspot_agent import MeetSpotAgent, create_meetspot_agent
|
||||
from app.agent.tools import (
|
||||
CalculateCenterTool,
|
||||
GeocodeTool,
|
||||
GenerateRecommendationTool,
|
||||
SearchPOITool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"MeetSpotAgent",
|
||||
"create_meetspot_agent",
|
||||
"GeocodeTool",
|
||||
"CalculateCenterTool",
|
||||
"SearchPOITool",
|
||||
"GenerateRecommendationTool",
|
||||
]
|
||||
171
MeetSpot/app/agent/base.py
Normal file
171
MeetSpot/app/agent/base.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Agent 基类 - 参考 OpenManus BaseAgent 设计"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Memory, Message, ROLE_TYPE
|
||||
|
||||
|
||||
class BaseAgent(BaseModel, ABC):
|
||||
"""Agent 基类
|
||||
|
||||
提供基础的状态管理、记忆管理和执行循环。
|
||||
子类需要实现 step() 方法来定义具体行为。
|
||||
"""
|
||||
|
||||
# 核心属性
|
||||
name: str = Field(default="BaseAgent", description="Agent 名称")
|
||||
description: Optional[str] = Field(default=None, description="Agent 描述")
|
||||
|
||||
# 提示词
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
||||
next_step_prompt: Optional[str] = Field(default=None, description="下一步提示词")
|
||||
|
||||
# 依赖
|
||||
llm: Optional[LLM] = Field(default=None, description="LLM 实例")
|
||||
memory: Memory = Field(default_factory=Memory, description="Agent 记忆")
|
||||
state: AgentState = Field(default=AgentState.IDLE, description="当前状态")
|
||||
|
||||
# 执行控制
|
||||
max_steps: int = Field(default=10, description="最大执行步数")
|
||||
current_step: int = Field(default=0, description="当前步数")
|
||||
|
||||
# 重复检测阈值
|
||||
duplicate_threshold: int = 2
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize_agent(self) -> "BaseAgent":
|
||||
"""初始化 Agent"""
|
||||
if self.llm is None:
|
||||
try:
|
||||
self.llm = LLM()
|
||||
except Exception as e:
|
||||
logger.warning(f"无法初始化 LLM: {e}")
|
||||
if not isinstance(self.memory, Memory):
|
||||
self.memory = Memory()
|
||||
return self
|
||||
|
||||
@asynccontextmanager
|
||||
async def state_context(self, new_state: AgentState):
|
||||
"""状态上下文管理器,用于安全的状态转换"""
|
||||
if not isinstance(new_state, AgentState):
|
||||
raise ValueError(f"无效状态: {new_state}")
|
||||
|
||||
previous_state = self.state
|
||||
self.state = new_state
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
self.state = AgentState.ERROR
|
||||
raise e
|
||||
finally:
|
||||
self.state = previous_state
|
||||
|
||||
def update_memory(
|
||||
self,
|
||||
role: ROLE_TYPE,
|
||||
content: str,
|
||||
base64_image: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""添加消息到记忆"""
|
||||
message_map = {
|
||||
"user": Message.user_message,
|
||||
"system": Message.system_message,
|
||||
"assistant": Message.assistant_message,
|
||||
"tool": lambda content, **kw: Message.tool_message(content, **kw),
|
||||
}
|
||||
|
||||
if role not in message_map:
|
||||
raise ValueError(f"不支持的消息角色: {role}")
|
||||
|
||||
if role == "tool":
|
||||
self.memory.add_message(message_map[role](content, **kwargs))
|
||||
else:
|
||||
self.memory.add_message(message_map[role](content, base64_image=base64_image))
|
||||
|
||||
async def run(self, request: Optional[str] = None) -> str:
|
||||
"""执行 Agent 主循环
|
||||
|
||||
Args:
|
||||
request: 可选的初始用户请求
|
||||
|
||||
Returns:
|
||||
执行结果摘要
|
||||
"""
|
||||
if self.state != AgentState.IDLE:
|
||||
raise RuntimeError(f"无法从状态 {self.state} 启动 Agent")
|
||||
|
||||
if request:
|
||||
self.update_memory("user", request)
|
||||
|
||||
results: List[str] = []
|
||||
async with self.state_context(AgentState.RUNNING):
|
||||
while (
|
||||
self.current_step < self.max_steps
|
||||
and self.state != AgentState.FINISHED
|
||||
):
|
||||
self.current_step += 1
|
||||
logger.info(f"执行步骤 {self.current_step}/{self.max_steps}")
|
||||
step_result = await self.step()
|
||||
|
||||
# 检测卡住状态
|
||||
if self.is_stuck():
|
||||
self.handle_stuck_state()
|
||||
|
||||
results.append(f"Step {self.current_step}: {step_result}")
|
||||
|
||||
if self.current_step >= self.max_steps:
|
||||
self.current_step = 0
|
||||
self.state = AgentState.IDLE
|
||||
results.append(f"已终止: 达到最大步数 ({self.max_steps})")
|
||||
|
||||
return "\n".join(results) if results else "未执行任何步骤"
|
||||
|
||||
@abstractmethod
|
||||
async def step(self) -> str:
|
||||
"""执行单步操作 - 子类必须实现"""
|
||||
pass
|
||||
|
||||
def handle_stuck_state(self):
|
||||
"""处理卡住状态"""
|
||||
stuck_prompt = "检测到重复响应。请考虑新策略,避免重复已尝试过的无效路径。"
|
||||
self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt or ''}"
|
||||
logger.warning(f"Agent 检测到卡住状态,已添加提示")
|
||||
|
||||
def is_stuck(self) -> bool:
|
||||
"""检测是否陷入循环"""
|
||||
if len(self.memory.messages) < 2:
|
||||
return False
|
||||
|
||||
last_message = self.memory.messages[-1]
|
||||
if not last_message.content:
|
||||
return False
|
||||
|
||||
# 统计相同内容出现次数
|
||||
duplicate_count = sum(
|
||||
1
|
||||
for msg in reversed(self.memory.messages[:-1])
|
||||
if msg.role == "assistant" and msg.content == last_message.content
|
||||
)
|
||||
|
||||
return duplicate_count >= self.duplicate_threshold
|
||||
|
||||
@property
|
||||
def messages(self) -> List[Message]:
|
||||
"""获取记忆中的消息列表"""
|
||||
return self.memory.messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value: List[Message]):
|
||||
"""设置记忆中的消息列表"""
|
||||
self.memory.messages = value
|
||||
361
MeetSpot/app/agent/meetspot_agent.py
Normal file
361
MeetSpot/app/agent/meetspot_agent.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""MeetSpotAgent - 智能会面地点推荐 Agent
|
||||
|
||||
基于 ReAct 模式实现的智能推荐代理,通过工具调用完成地点推荐任务。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.agent.base import BaseAgent
|
||||
from app.agent.tools import (
|
||||
CalculateCenterTool,
|
||||
GeocodeTool,
|
||||
GenerateRecommendationTool,
|
||||
SearchPOITool,
|
||||
)
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Message
|
||||
from app.tool.tool_collection import ToolCollection
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """你是 MeetSpot 智能会面助手,帮助用户找到最佳会面地点。
|
||||
|
||||
## 你的能力
|
||||
你可以使用以下工具来完成任务:
|
||||
|
||||
1. **geocode** - 地理编码
|
||||
- 将地址转换为经纬度坐标
|
||||
- 支持大学简称(北大、清华)、地标、商圈等
|
||||
- 返回坐标和格式化地址
|
||||
|
||||
2. **calculate_center** - 计算中心点
|
||||
- 计算多个位置的几何中心
|
||||
- 作为最佳会面位置的参考点
|
||||
- 使用球面几何确保精确
|
||||
|
||||
3. **search_poi** - 搜索场所
|
||||
- 在中心点附近搜索各类场所
|
||||
- 支持咖啡馆、餐厅、图书馆、健身房等
|
||||
- 返回名称、地址、评分、距离等
|
||||
|
||||
4. **generate_recommendation** - 生成推荐
|
||||
- 分析搜索结果
|
||||
- 根据评分、距离、用户需求排序
|
||||
- 生成个性化推荐理由
|
||||
|
||||
## 工作流程
|
||||
请按以下步骤执行:
|
||||
|
||||
1. **理解任务** - 分析用户提供的位置和需求
|
||||
2. **地理编码** - 依次对每个地址使用 geocode 获取坐标
|
||||
3. **计算中心** - 使用 calculate_center 计算最佳会面点
|
||||
4. **搜索场所** - 使用 search_poi 在中心点附近搜索
|
||||
5. **生成推荐** - 使用 generate_recommendation 生成最终推荐
|
||||
|
||||
## 输出要求
|
||||
- 推荐 3-5 个最佳场所
|
||||
- 为每个场所说明推荐理由(距离、评分、特色)
|
||||
- 考虑用户的特殊需求(停车、安静、商务等)
|
||||
- 使用中文回复
|
||||
|
||||
## 注意事项
|
||||
- 确保在调用工具前已获取所有必要参数
|
||||
- 如果地址解析失败,提供具体的错误信息和建议
|
||||
- 如果搜索无结果,尝试调整搜索关键词或扩大半径
|
||||
"""
|
||||
|
||||
|
||||
class MeetSpotAgent(BaseAgent):
|
||||
"""MeetSpot 智能会面推荐 Agent
|
||||
|
||||
基于 ReAct 模式的智能代理,通过 think() -> act() 循环完成推荐任务。
|
||||
"""
|
||||
|
||||
name: str = "MeetSpotAgent"
|
||||
description: str = "智能会面地点推荐助手"
|
||||
|
||||
system_prompt: str = SYSTEM_PROMPT
|
||||
next_step_prompt: str = "请继续执行下一步,或者如果已完成所有工具调用,请生成最终推荐结果。"
|
||||
|
||||
max_steps: int = 15 # 允许更多步骤以完成复杂任务
|
||||
|
||||
# 工具集合
|
||||
available_tools: ToolCollection = Field(default=None)
|
||||
|
||||
# 当前工具调用
|
||||
tool_calls: List[Any] = Field(default_factory=list)
|
||||
|
||||
# 存储中间结果
|
||||
geocode_results: List[Dict] = Field(default_factory=list)
|
||||
center_point: Optional[Dict] = None
|
||||
search_results: List[Dict] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
# 初始化工具集合
|
||||
if self.available_tools is None:
|
||||
self.available_tools = ToolCollection(
|
||||
GeocodeTool(),
|
||||
CalculateCenterTool(),
|
||||
SearchPOITool(),
|
||||
GenerateRecommendationTool()
|
||||
)
|
||||
|
||||
async def step(self) -> str:
|
||||
"""执行一步: think + act
|
||||
|
||||
Returns:
|
||||
步骤执行结果的描述
|
||||
"""
|
||||
# Think: 决定下一步行动
|
||||
should_continue = await self.think()
|
||||
|
||||
if not should_continue:
|
||||
self.state = AgentState.FINISHED
|
||||
return "任务完成"
|
||||
|
||||
# Act: 执行工具调用
|
||||
result = await self.act()
|
||||
return result
|
||||
|
||||
async def think(self) -> bool:
|
||||
"""思考阶段 - 决定下一步行动
|
||||
|
||||
使用 LLM 分析当前状态,决定是否需要调用工具以及调用哪个工具。
|
||||
|
||||
Returns:
|
||||
是否需要继续执行
|
||||
"""
|
||||
# 构建消息
|
||||
messages = self.memory.messages.copy()
|
||||
|
||||
# 添加提示引导下一步
|
||||
if self.next_step_prompt and self.current_step > 1:
|
||||
messages.append(Message.user_message(self.next_step_prompt))
|
||||
|
||||
# 调用 LLM 获取响应
|
||||
response = await self.llm.ask_tool(
|
||||
messages=messages,
|
||||
system_msgs=[Message.system_message(self.system_prompt)],
|
||||
tools=self.available_tools.to_params(),
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
if response is None:
|
||||
logger.warning("LLM 返回空响应")
|
||||
return False
|
||||
|
||||
# 提取工具调用和内容
|
||||
self.tool_calls = response.tool_calls or []
|
||||
content = response.content or ""
|
||||
|
||||
logger.info(f"Agent 思考: {content[:200]}..." if len(content) > 200 else f"Agent 思考: {content}")
|
||||
|
||||
if self.tool_calls:
|
||||
tool_names = [tc.function.name for tc in self.tool_calls]
|
||||
logger.info(f"选择工具: {tool_names}")
|
||||
|
||||
# 保存 assistant 消息到记忆
|
||||
if self.tool_calls:
|
||||
# 带工具调用的消息
|
||||
tool_calls_data = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in self.tool_calls
|
||||
]
|
||||
self.memory.add_message(Message(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls_data
|
||||
))
|
||||
elif content:
|
||||
# 纯文本消息(可能是最终回复)
|
||||
self.memory.add_message(Message.assistant_message(content))
|
||||
# 如果没有工具调用且有内容,可能是最终回复
|
||||
if "推荐" in content and len(content) > 100:
|
||||
return False # 结束循环
|
||||
|
||||
return bool(self.tool_calls) or bool(content)
|
||||
|
||||
async def act(self) -> str:
|
||||
"""行动阶段 - 执行工具调用
|
||||
|
||||
执行思考阶段决定的工具调用,并将结果添加到记忆。
|
||||
|
||||
Returns:
|
||||
工具执行结果的描述
|
||||
"""
|
||||
if not self.tool_calls:
|
||||
# 没有工具调用,返回最后一条消息的内容
|
||||
return self.memory.messages[-1].content or "无操作"
|
||||
|
||||
results = []
|
||||
for call in self.tool_calls:
|
||||
tool_name = call.function.name
|
||||
tool_args = call.function.arguments
|
||||
|
||||
try:
|
||||
# 解析参数
|
||||
args = json.loads(tool_args) if isinstance(tool_args, str) else tool_args
|
||||
|
||||
# 执行工具
|
||||
logger.info(f"执行工具: {tool_name}, 参数: {args}")
|
||||
result = await self.available_tools.execute(name=tool_name, tool_input=args)
|
||||
|
||||
# 保存中间结果
|
||||
self._save_intermediate_result(tool_name, result, args)
|
||||
|
||||
# 将工具结果添加到记忆
|
||||
result_str = str(result)
|
||||
self.memory.add_message(Message.tool_message(
|
||||
content=result_str,
|
||||
tool_call_id=call.id,
|
||||
name=tool_name
|
||||
))
|
||||
|
||||
logger.info(f"工具 {tool_name} 完成")
|
||||
results.append(f"{tool_name}: 成功")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"工具执行失败: {str(e)}"
|
||||
logger.error(f"{tool_name} {error_msg}")
|
||||
|
||||
# 添加错误消息到记忆
|
||||
self.memory.add_message(Message.tool_message(
|
||||
content=error_msg,
|
||||
tool_call_id=call.id,
|
||||
name=tool_name
|
||||
))
|
||||
results.append(f"{tool_name}: 失败 - {str(e)}")
|
||||
|
||||
return " | ".join(results)
|
||||
|
||||
def _save_intermediate_result(self, tool_name: str, result: Any, args: Dict) -> None:
|
||||
"""保存工具执行的中间结果
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
result: 工具执行结果
|
||||
args: 工具参数
|
||||
"""
|
||||
try:
|
||||
# 解析结果
|
||||
if hasattr(result, 'output') and result.output:
|
||||
data = json.loads(result.output) if isinstance(result.output, str) else result.output
|
||||
else:
|
||||
return
|
||||
|
||||
if tool_name == "geocode" and data:
|
||||
self.geocode_results.append({
|
||||
"address": args.get("address", ""),
|
||||
"lng": data.get("lng"),
|
||||
"lat": data.get("lat"),
|
||||
"formatted_address": data.get("formatted_address", "")
|
||||
})
|
||||
|
||||
elif tool_name == "calculate_center" and data:
|
||||
self.center_point = data.get("center")
|
||||
|
||||
elif tool_name == "search_poi" and data:
|
||||
places = data.get("places", [])
|
||||
self.search_results.extend(places)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"保存中间结果时出错: {e}")
|
||||
|
||||
async def recommend(
|
||||
self,
|
||||
locations: List[str],
|
||||
keywords: str = "咖啡馆",
|
||||
requirements: str = ""
|
||||
) -> Dict:
|
||||
"""执行推荐任务
|
||||
|
||||
这是 Agent 的主要入口方法,接收用户输入并返回推荐结果。
|
||||
|
||||
Args:
|
||||
locations: 参与者位置列表
|
||||
keywords: 搜索关键词(场所类型)
|
||||
requirements: 用户特殊需求
|
||||
|
||||
Returns:
|
||||
包含推荐结果的字典
|
||||
"""
|
||||
# 重置状态
|
||||
self.geocode_results = []
|
||||
self.center_point = None
|
||||
self.search_results = []
|
||||
self.current_step = 0
|
||||
self.state = AgentState.IDLE
|
||||
self.memory.clear()
|
||||
|
||||
# 构建任务描述
|
||||
locations_str = "、".join(locations)
|
||||
task = f"""请帮我找到适合会面的地点:
|
||||
|
||||
**参与者位置**:{locations_str}
|
||||
**想找的场所类型**:{keywords}
|
||||
**特殊需求**:{requirements or "无特殊需求"}
|
||||
|
||||
请按照工作流程执行:
|
||||
1. 先用 geocode 工具获取每个位置的坐标
|
||||
2. 用 calculate_center 计算中心点
|
||||
3. 用 search_poi 搜索附近的 {keywords}
|
||||
4. 用 generate_recommendation 生成推荐
|
||||
|
||||
最后请用中文总结推荐结果。"""
|
||||
|
||||
# 执行任务
|
||||
result = await self.run(task)
|
||||
|
||||
# 格式化返回结果
|
||||
return self._format_result(result)
|
||||
|
||||
def _format_result(self, raw_result: str) -> Dict:
|
||||
"""格式化最终结果
|
||||
|
||||
Args:
|
||||
raw_result: Agent 执行的原始结果
|
||||
|
||||
Returns:
|
||||
格式化的结果字典
|
||||
"""
|
||||
# 获取最后一条 assistant 消息作为最终推荐
|
||||
final_recommendation = ""
|
||||
for msg in reversed(self.memory.messages):
|
||||
if msg.role == "assistant" and msg.content:
|
||||
final_recommendation = msg.content
|
||||
break
|
||||
|
||||
return {
|
||||
"success": self.state == AgentState.IDLE, # IDLE 表示正常完成
|
||||
"recommendation": final_recommendation,
|
||||
"geocode_results": self.geocode_results,
|
||||
"center_point": self.center_point,
|
||||
"search_results": self.search_results[:10], # 限制返回数量
|
||||
"steps_executed": self.current_step,
|
||||
"raw_output": raw_result
|
||||
}
|
||||
|
||||
|
||||
# 创建默认 Agent 实例的工厂函数
|
||||
def create_meetspot_agent() -> MeetSpotAgent:
|
||||
"""创建 MeetSpotAgent 实例
|
||||
|
||||
Returns:
|
||||
配置好的 MeetSpotAgent 实例
|
||||
"""
|
||||
return MeetSpotAgent()
|
||||
514
MeetSpot/app/agent/tools.py
Normal file
514
MeetSpot/app/agent/tools.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""MeetSpot Agent 工具集 - 封装推荐系统的核心功能"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.tool.base import BaseTool, ToolResult
|
||||
from app.logger import logger
|
||||
|
||||
|
||||
class GeocodeTool(BaseTool):
|
||||
"""地理编码工具 - 将地址转换为经纬度坐标"""
|
||||
|
||||
name: str = "geocode"
|
||||
description: str = """将地址或地点名称转换为经纬度坐标。
|
||||
支持各种地址格式:
|
||||
- 完整地址:'北京市海淀区中关村大街1号'
|
||||
- 大学简称:'北大'、'清华'、'复旦'(自动扩展为完整地址)
|
||||
- 知名地标:'天安门'、'外滩'、'广州塔'
|
||||
- 商圈区域:'三里屯'、'王府井'
|
||||
|
||||
返回地址的经纬度坐标和格式化地址。"""
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "string",
|
||||
"description": "地址或地点名称,如'北京大学'、'上海市浦东新区陆家嘴'"
|
||||
}
|
||||
},
|
||||
"required": ["address"]
|
||||
}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_recommender(self):
|
||||
"""延迟加载推荐器,并确保 API key 已设置"""
|
||||
if not hasattr(self, '_cached_recommender'):
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.config import config
|
||||
recommender = CafeRecommender()
|
||||
# 确保 API key 已设置
|
||||
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
|
||||
recommender.api_key = config.amap.api_key
|
||||
object.__setattr__(self, '_cached_recommender', recommender)
|
||||
return self._cached_recommender
|
||||
|
||||
async def execute(self, address: str) -> ToolResult:
|
||||
"""执行地理编码"""
|
||||
try:
|
||||
recommender = self._get_recommender()
|
||||
result = await recommender._geocode(address)
|
||||
|
||||
if result:
|
||||
location = result.get("location", "")
|
||||
lng, lat = location.split(",") if location else (None, None)
|
||||
|
||||
return BaseTool.success_response({
|
||||
"address": address,
|
||||
"formatted_address": result.get("formatted_address", ""),
|
||||
"location": location,
|
||||
"lng": float(lng) if lng else None,
|
||||
"lat": float(lat) if lat else None,
|
||||
"city": result.get("city", ""),
|
||||
"district": result.get("district", "")
|
||||
})
|
||||
|
||||
return BaseTool.fail_response(f"无法解析地址: {address}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"地理编码失败: {e}")
|
||||
return BaseTool.fail_response(f"地理编码错误: {str(e)}")
|
||||
|
||||
|
||||
class CalculateCenterTool(BaseTool):
|
||||
"""智能中心点工具 - 计算多个位置的最佳会面点
|
||||
|
||||
使用智能算法,综合考虑:
|
||||
- POI 密度:周边是否有足够的目标场所
|
||||
- 交通便利性:是否靠近地铁站/公交站
|
||||
- 公平性:对所有参与者的距离是否均衡
|
||||
"""
|
||||
|
||||
name: str = "calculate_center"
|
||||
description: str = """智能计算最佳会面中心点。
|
||||
|
||||
不同于简单的几何中心,本工具会:
|
||||
1. 在几何中心周围生成多个候选点
|
||||
2. 评估每个候选点的 POI 密度、交通便利性和公平性
|
||||
3. 返回综合得分最高的点作为最佳会面位置
|
||||
|
||||
这样可以避免中心点落在河流、荒地等不适合的位置。"""
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"coordinates": {
|
||||
"type": "array",
|
||||
"description": "坐标点列表,每个元素包含 lng(经度)、lat(纬度)和可选的 name(名称)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lng": {"type": "number", "description": "经度"},
|
||||
"lat": {"type": "number", "description": "纬度"},
|
||||
"name": {"type": "string", "description": "位置名称(可选)"}
|
||||
},
|
||||
"required": ["lng", "lat"]
|
||||
}
|
||||
},
|
||||
"keywords": {
|
||||
"type": "string",
|
||||
"description": "搜索的场所类型,如'咖啡馆'、'餐厅',用于评估 POI 密度",
|
||||
"default": "咖啡馆"
|
||||
},
|
||||
"use_smart_algorithm": {
|
||||
"type": "boolean",
|
||||
"description": "是否使用智能算法(考虑 POI 密度和交通),默认 true",
|
||||
"default": True
|
||||
}
|
||||
},
|
||||
"required": ["coordinates"]
|
||||
}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_recommender(self):
|
||||
"""延迟加载推荐器,并确保 API key 已设置"""
|
||||
if not hasattr(self, '_cached_recommender'):
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.config import config
|
||||
recommender = CafeRecommender()
|
||||
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
|
||||
recommender.api_key = config.amap.api_key
|
||||
object.__setattr__(self, '_cached_recommender', recommender)
|
||||
return self._cached_recommender
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
coordinates: List[Dict],
|
||||
keywords: str = "咖啡馆",
|
||||
use_smart_algorithm: bool = True
|
||||
) -> ToolResult:
|
||||
"""计算最佳中心点"""
|
||||
try:
|
||||
if not coordinates or len(coordinates) < 2:
|
||||
return BaseTool.fail_response("至少需要2个坐标点来计算中心")
|
||||
|
||||
recommender = self._get_recommender()
|
||||
|
||||
# 转换为 (lng, lat) 元组列表
|
||||
coord_tuples = [(c["lng"], c["lat"]) for c in coordinates]
|
||||
|
||||
if use_smart_algorithm:
|
||||
# 使用智能中心点算法
|
||||
center, evaluation_details = await recommender._calculate_smart_center(
|
||||
coord_tuples, keywords
|
||||
)
|
||||
logger.info(f"智能中心点算法完成,最优中心: {center}")
|
||||
else:
|
||||
# 使用简单几何中心
|
||||
center = recommender._calculate_center_point(coord_tuples)
|
||||
evaluation_details = {"algorithm": "geometric_center"}
|
||||
|
||||
# 计算每个点到中心的距离
|
||||
distances = []
|
||||
for c in coordinates:
|
||||
dist = recommender._calculate_distance(center, (c["lng"], c["lat"]))
|
||||
distances.append({
|
||||
"name": c.get("name", f"({c['lng']:.4f}, {c['lat']:.4f})"),
|
||||
"distance_to_center": round(dist, 0)
|
||||
})
|
||||
|
||||
max_dist = max(d["distance_to_center"] for d in distances)
|
||||
min_dist = min(d["distance_to_center"] for d in distances)
|
||||
|
||||
result = {
|
||||
"center": {
|
||||
"lng": round(center[0], 6),
|
||||
"lat": round(center[1], 6)
|
||||
},
|
||||
"algorithm": "smart" if use_smart_algorithm else "geometric",
|
||||
"input_count": len(coordinates),
|
||||
"distances": distances,
|
||||
"max_distance": max_dist,
|
||||
"fairness_score": round(100 - (max_dist - min_dist) / 100, 1)
|
||||
}
|
||||
|
||||
# 添加智能算法的评估详情
|
||||
if use_smart_algorithm and evaluation_details:
|
||||
result["evaluation"] = {
|
||||
"geo_center": evaluation_details.get("geo_center"),
|
||||
"best_score": evaluation_details.get("best_score"),
|
||||
"top_candidates": len(evaluation_details.get("all_candidates", []))
|
||||
}
|
||||
|
||||
return BaseTool.success_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算中心点失败: {e}")
|
||||
return BaseTool.fail_response(f"计算中心点错误: {str(e)}")
|
||||
|
||||
|
||||
class SearchPOITool(BaseTool):
|
||||
"""搜索POI工具 - 在指定位置周围搜索场所"""
|
||||
|
||||
name: str = "search_poi"
|
||||
description: str = """在指定中心点周围搜索各类场所(POI)。
|
||||
支持搜索:咖啡馆、餐厅、图书馆、健身房、KTV、电影院、商场等。
|
||||
返回场所的名称、地址、评分、距离等信息。"""
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"center_lng": {
|
||||
"type": "number",
|
||||
"description": "中心点经度"
|
||||
},
|
||||
"center_lat": {
|
||||
"type": "number",
|
||||
"description": "中心点纬度"
|
||||
},
|
||||
"keywords": {
|
||||
"type": "string",
|
||||
"description": "搜索关键词,如'咖啡馆'、'餐厅'、'图书馆'"
|
||||
},
|
||||
"radius": {
|
||||
"type": "integer",
|
||||
"description": "搜索半径(米),默认3000米",
|
||||
"default": 3000
|
||||
}
|
||||
},
|
||||
"required": ["center_lng", "center_lat", "keywords"]
|
||||
}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_recommender(self):
|
||||
"""延迟加载推荐器,并确保 API key 已设置"""
|
||||
if not hasattr(self, '_cached_recommender'):
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.config import config
|
||||
recommender = CafeRecommender()
|
||||
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
|
||||
recommender.api_key = config.amap.api_key
|
||||
object.__setattr__(self, '_cached_recommender', recommender)
|
||||
return self._cached_recommender
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
center_lng: float,
|
||||
center_lat: float,
|
||||
keywords: str,
|
||||
radius: int = 3000
|
||||
) -> ToolResult:
|
||||
"""搜索POI"""
|
||||
try:
|
||||
recommender = self._get_recommender()
|
||||
center = f"{center_lng},{center_lat}"
|
||||
|
||||
places = await recommender._search_pois(
|
||||
location=center,
|
||||
keywords=keywords,
|
||||
radius=radius,
|
||||
types="",
|
||||
offset=20
|
||||
)
|
||||
|
||||
if not places:
|
||||
return BaseTool.fail_response(
|
||||
f"在 ({center_lng:.4f}, {center_lat:.4f}) 附近 {radius}米范围内"
|
||||
f"未找到与 '{keywords}' 相关的场所"
|
||||
)
|
||||
|
||||
# 简化返回数据
|
||||
simplified = []
|
||||
for p in places[:15]: # 最多返回15个
|
||||
biz_ext = p.get("biz_ext", {}) or {}
|
||||
location = p.get("location", "")
|
||||
lng, lat = location.split(",") if location else (0, 0)
|
||||
|
||||
# 计算到中心的距离
|
||||
distance = recommender._calculate_distance(
|
||||
(center_lng, center_lat),
|
||||
(float(lng), float(lat))
|
||||
) if location else 0
|
||||
|
||||
simplified.append({
|
||||
"name": p.get("name", ""),
|
||||
"address": p.get("address", ""),
|
||||
"rating": biz_ext.get("rating", "N/A"),
|
||||
"cost": biz_ext.get("cost", ""),
|
||||
"location": location,
|
||||
"lng": float(lng) if lng else None,
|
||||
"lat": float(lat) if lat else None,
|
||||
"distance": round(distance, 0),
|
||||
"tel": p.get("tel", ""),
|
||||
"tag": p.get("tag", ""),
|
||||
"type": p.get("type", "")
|
||||
})
|
||||
|
||||
# 按距离排序
|
||||
simplified.sort(key=lambda x: x.get("distance", 9999))
|
||||
|
||||
return BaseTool.success_response({
|
||||
"places": simplified,
|
||||
"count": len(simplified),
|
||||
"keywords": keywords,
|
||||
"center": {"lng": center_lng, "lat": center_lat},
|
||||
"radius": radius
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"POI搜索失败: {e}")
|
||||
return BaseTool.fail_response(f"POI搜索错误: {str(e)}")
|
||||
|
||||
|
||||
class GenerateRecommendationTool(BaseTool):
|
||||
"""智能推荐工具 - 使用 LLM 生成个性化推荐结果
|
||||
|
||||
结合规则评分和 LLM 智能评分,生成更精准的推荐:
|
||||
- 规则评分:基于距离、评分、热度等客观指标
|
||||
- LLM 评分:理解用户需求语义,评估场所匹配度
|
||||
"""
|
||||
|
||||
name: str = "generate_recommendation"
|
||||
description: str = """智能生成会面地点推荐。
|
||||
|
||||
本工具使用双层评分系统:
|
||||
1. 规则评分(40%):基于距离、评分、热度等客观指标
|
||||
2. LLM 智能评分(60%):理解用户需求,评估场所特色与需求的匹配度
|
||||
|
||||
最终生成个性化的推荐理由,帮助用户做出最佳选择。"""
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"places": {
|
||||
"type": "array",
|
||||
"description": "候选场所列表(来自search_poi的结果)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "场所名称"},
|
||||
"address": {"type": "string", "description": "地址"},
|
||||
"rating": {"type": "string", "description": "评分"},
|
||||
"distance": {"type": "number", "description": "距中心点距离"},
|
||||
"location": {"type": "string", "description": "坐标"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"center": {
|
||||
"type": "object",
|
||||
"description": "中心点坐标",
|
||||
"properties": {
|
||||
"lng": {"type": "number", "description": "经度"},
|
||||
"lat": {"type": "number", "description": "纬度"}
|
||||
},
|
||||
"required": ["lng", "lat"]
|
||||
},
|
||||
"participant_locations": {
|
||||
"type": "array",
|
||||
"description": "参与者位置名称列表,用于 LLM 评估公平性",
|
||||
"items": {"type": "string"},
|
||||
"default": []
|
||||
},
|
||||
"keywords": {
|
||||
"type": "string",
|
||||
"description": "搜索的场所类型,如'咖啡馆'、'餐厅'",
|
||||
"default": "咖啡馆"
|
||||
},
|
||||
"user_requirements": {
|
||||
"type": "string",
|
||||
"description": "用户的特殊需求,如'停车方便'、'环境安静'",
|
||||
"default": ""
|
||||
},
|
||||
"recommendation_count": {
|
||||
"type": "integer",
|
||||
"description": "推荐数量,默认5个",
|
||||
"default": 5
|
||||
},
|
||||
"use_llm_ranking": {
|
||||
"type": "boolean",
|
||||
"description": "是否使用 LLM 智能排序,默认 true",
|
||||
"default": True
|
||||
}
|
||||
},
|
||||
"required": ["places", "center"]
|
||||
}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_recommender(self):
|
||||
"""延迟加载推荐器,并确保 API key 已设置"""
|
||||
if not hasattr(self, '_cached_recommender'):
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.config import config
|
||||
recommender = CafeRecommender()
|
||||
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
|
||||
recommender.api_key = config.amap.api_key
|
||||
object.__setattr__(self, '_cached_recommender', recommender)
|
||||
return self._cached_recommender
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
places: List[Dict],
|
||||
center: Dict,
|
||||
participant_locations: List[str] = None,
|
||||
keywords: str = "咖啡馆",
|
||||
user_requirements: str = "",
|
||||
recommendation_count: int = 5,
|
||||
use_llm_ranking: bool = True
|
||||
) -> ToolResult:
|
||||
"""智能生成推荐"""
|
||||
try:
|
||||
if not places:
|
||||
return BaseTool.fail_response("没有候选场所可供推荐")
|
||||
|
||||
recommender = self._get_recommender()
|
||||
center_point = (center["lng"], center["lat"])
|
||||
|
||||
# 1. 先用规则评分进行初步排序
|
||||
ranked = recommender._rank_places(
|
||||
places=places,
|
||||
center_point=center_point,
|
||||
user_requirements=user_requirements,
|
||||
keywords=keywords
|
||||
)
|
||||
|
||||
# 2. 如果启用 LLM 智能排序,进行重排序
|
||||
if use_llm_ranking and participant_locations:
|
||||
logger.info("启用 LLM 智能排序")
|
||||
ranked = await recommender._llm_smart_ranking(
|
||||
places=ranked,
|
||||
user_requirements=user_requirements,
|
||||
participant_locations=participant_locations or [],
|
||||
keywords=keywords,
|
||||
top_n=recommendation_count + 3 # 多取几个以便筛选
|
||||
)
|
||||
|
||||
# 取前N个推荐
|
||||
top_places = ranked[:recommendation_count]
|
||||
|
||||
# 生成推荐结果
|
||||
recommendations = []
|
||||
for i, place in enumerate(top_places, 1):
|
||||
score = place.get("_final_score") or place.get("_score", 0)
|
||||
distance = place.get("_distance") or place.get("distance", 0)
|
||||
rating = place.get("_raw_rating") or place.get("rating", "N/A")
|
||||
|
||||
# 优先使用 LLM 生成的理由
|
||||
llm_reason = place.get("_llm_reason", "")
|
||||
rule_reason = place.get("_recommendation_reason", "")
|
||||
|
||||
if llm_reason:
|
||||
reasons = [llm_reason]
|
||||
elif rule_reason:
|
||||
reasons = [rule_reason]
|
||||
else:
|
||||
# 兜底:构建基础推荐理由
|
||||
reasons = []
|
||||
if distance <= 500:
|
||||
reasons.append("距离中心点很近")
|
||||
elif distance <= 1000:
|
||||
reasons.append("距离适中")
|
||||
|
||||
if rating != "N/A":
|
||||
try:
|
||||
r = float(rating)
|
||||
if r >= 4.5:
|
||||
reasons.append("口碑优秀")
|
||||
elif r >= 4.0:
|
||||
reasons.append("评价良好")
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if not reasons:
|
||||
reasons = ["综合评分较高"]
|
||||
|
||||
recommendations.append({
|
||||
"rank": i,
|
||||
"name": place.get("name", ""),
|
||||
"address": place.get("address", ""),
|
||||
"rating": str(rating) if rating else "N/A",
|
||||
"distance": round(distance, 0),
|
||||
"score": round(score, 1),
|
||||
"llm_score": place.get("_llm_score", 0),
|
||||
"tel": place.get("tel", ""),
|
||||
"reasons": reasons,
|
||||
"location": place.get("location", ""),
|
||||
"scoring_method": "llm+rule" if place.get("_llm_score") else "rule"
|
||||
})
|
||||
|
||||
return BaseTool.success_response({
|
||||
"recommendations": recommendations,
|
||||
"total_candidates": len(places),
|
||||
"user_requirements": user_requirements,
|
||||
"center": center,
|
||||
"llm_ranking_used": use_llm_ranking and bool(participant_locations)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成推荐失败: {e}")
|
||||
return BaseTool.fail_response(f"生成推荐错误: {str(e)}")
|
||||
|
||||
|
||||
# 导出所有工具
|
||||
__all__ = [
|
||||
"GeocodeTool",
|
||||
"CalculateCenterTool",
|
||||
"SearchPOITool",
|
||||
"GenerateRecommendationTool"
|
||||
]
|
||||
Reference in New Issue
Block a user