first commit

This commit is contained in:
ytc1012
2026-02-04 16:11:55 +08:00
commit 0f3ee050dc
165 changed files with 25795 additions and 0 deletions

9
MeetSpot/app/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
# Python version check: 3.11-3.13
import sys
if sys.version_info < (3, 11) or sys.version_info > (3, 13):
print(
"Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
ver=".".join(map(str, sys.version_info))
)
)

View File

@@ -0,0 +1,20 @@
"""MeetSpot Agent Module - 基于 OpenManus 架构的智能推荐 Agent"""
from app.agent.base import BaseAgent
from app.agent.meetspot_agent import MeetSpotAgent, create_meetspot_agent
from app.agent.tools import (
CalculateCenterTool,
GeocodeTool,
GenerateRecommendationTool,
SearchPOITool,
)
__all__ = [
"BaseAgent",
"MeetSpotAgent",
"create_meetspot_agent",
"GeocodeTool",
"CalculateCenterTool",
"SearchPOITool",
"GenerateRecommendationTool",
]

171
MeetSpot/app/agent/base.py Normal file
View File

@@ -0,0 +1,171 @@
"""Agent 基类 - 参考 OpenManus BaseAgent 设计"""
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, model_validator
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Memory, Message, ROLE_TYPE
class BaseAgent(BaseModel, ABC):
"""Agent 基类
提供基础的状态管理、记忆管理和执行循环。
子类需要实现 step() 方法来定义具体行为。
"""
# 核心属性
name: str = Field(default="BaseAgent", description="Agent 名称")
description: Optional[str] = Field(default=None, description="Agent 描述")
# 提示词
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
next_step_prompt: Optional[str] = Field(default=None, description="下一步提示词")
# 依赖
llm: Optional[LLM] = Field(default=None, description="LLM 实例")
memory: Memory = Field(default_factory=Memory, description="Agent 记忆")
state: AgentState = Field(default=AgentState.IDLE, description="当前状态")
# 执行控制
max_steps: int = Field(default=10, description="最大执行步数")
current_step: int = Field(default=0, description="当前步数")
# 重复检测阈值
duplicate_threshold: int = 2
class Config:
arbitrary_types_allowed = True
extra = "allow"
@model_validator(mode="after")
def initialize_agent(self) -> "BaseAgent":
"""初始化 Agent"""
if self.llm is None:
try:
self.llm = LLM()
except Exception as e:
logger.warning(f"无法初始化 LLM: {e}")
if not isinstance(self.memory, Memory):
self.memory = Memory()
return self
@asynccontextmanager
async def state_context(self, new_state: AgentState):
"""状态上下文管理器,用于安全的状态转换"""
if not isinstance(new_state, AgentState):
raise ValueError(f"无效状态: {new_state}")
previous_state = self.state
self.state = new_state
try:
yield
except Exception as e:
self.state = AgentState.ERROR
raise e
finally:
self.state = previous_state
def update_memory(
self,
role: ROLE_TYPE,
content: str,
base64_image: Optional[str] = None,
**kwargs,
) -> None:
"""添加消息到记忆"""
message_map = {
"user": Message.user_message,
"system": Message.system_message,
"assistant": Message.assistant_message,
"tool": lambda content, **kw: Message.tool_message(content, **kw),
}
if role not in message_map:
raise ValueError(f"不支持的消息角色: {role}")
if role == "tool":
self.memory.add_message(message_map[role](content, **kwargs))
else:
self.memory.add_message(message_map[role](content, base64_image=base64_image))
async def run(self, request: Optional[str] = None) -> str:
"""执行 Agent 主循环
Args:
request: 可选的初始用户请求
Returns:
执行结果摘要
"""
if self.state != AgentState.IDLE:
raise RuntimeError(f"无法从状态 {self.state} 启动 Agent")
if request:
self.update_memory("user", request)
results: List[str] = []
async with self.state_context(AgentState.RUNNING):
while (
self.current_step < self.max_steps
and self.state != AgentState.FINISHED
):
self.current_step += 1
logger.info(f"执行步骤 {self.current_step}/{self.max_steps}")
step_result = await self.step()
# 检测卡住状态
if self.is_stuck():
self.handle_stuck_state()
results.append(f"Step {self.current_step}: {step_result}")
if self.current_step >= self.max_steps:
self.current_step = 0
self.state = AgentState.IDLE
results.append(f"已终止: 达到最大步数 ({self.max_steps})")
return "\n".join(results) if results else "未执行任何步骤"
@abstractmethod
async def step(self) -> str:
"""执行单步操作 - 子类必须实现"""
pass
def handle_stuck_state(self):
"""处理卡住状态"""
stuck_prompt = "检测到重复响应。请考虑新策略,避免重复已尝试过的无效路径。"
self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt or ''}"
logger.warning(f"Agent 检测到卡住状态,已添加提示")
def is_stuck(self) -> bool:
"""检测是否陷入循环"""
if len(self.memory.messages) < 2:
return False
last_message = self.memory.messages[-1]
if not last_message.content:
return False
# 统计相同内容出现次数
duplicate_count = sum(
1
for msg in reversed(self.memory.messages[:-1])
if msg.role == "assistant" and msg.content == last_message.content
)
return duplicate_count >= self.duplicate_threshold
@property
def messages(self) -> List[Message]:
"""获取记忆中的消息列表"""
return self.memory.messages
@messages.setter
def messages(self, value: List[Message]):
"""设置记忆中的消息列表"""
self.memory.messages = value

View File

@@ -0,0 +1,361 @@
"""MeetSpotAgent - 智能会面地点推荐 Agent
基于 ReAct 模式实现的智能推荐代理,通过工具调用完成地点推荐任务。
"""
import json
from typing import Any, Dict, List, Optional
from pydantic import Field
from app.agent.base import BaseAgent
from app.agent.tools import (
CalculateCenterTool,
GeocodeTool,
GenerateRecommendationTool,
SearchPOITool,
)
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Message
from app.tool.tool_collection import ToolCollection
SYSTEM_PROMPT = """你是 MeetSpot 智能会面助手,帮助用户找到最佳会面地点。
## 你的能力
你可以使用以下工具来完成任务:
1. **geocode** - 地理编码
- 将地址转换为经纬度坐标
- 支持大学简称(北大、清华)、地标、商圈等
- 返回坐标和格式化地址
2. **calculate_center** - 计算中心点
- 计算多个位置的几何中心
- 作为最佳会面位置的参考点
- 使用球面几何确保精确
3. **search_poi** - 搜索场所
- 在中心点附近搜索各类场所
- 支持咖啡馆、餐厅、图书馆、健身房等
- 返回名称、地址、评分、距离等
4. **generate_recommendation** - 生成推荐
- 分析搜索结果
- 根据评分、距离、用户需求排序
- 生成个性化推荐理由
## 工作流程
请按以下步骤执行:
1. **理解任务** - 分析用户提供的位置和需求
2. **地理编码** - 依次对每个地址使用 geocode 获取坐标
3. **计算中心** - 使用 calculate_center 计算最佳会面点
4. **搜索场所** - 使用 search_poi 在中心点附近搜索
5. **生成推荐** - 使用 generate_recommendation 生成最终推荐
## 输出要求
- 推荐 3-5 个最佳场所
- 为每个场所说明推荐理由(距离、评分、特色)
- 考虑用户的特殊需求(停车、安静、商务等)
- 使用中文回复
## 注意事项
- 确保在调用工具前已获取所有必要参数
- 如果地址解析失败,提供具体的错误信息和建议
- 如果搜索无结果,尝试调整搜索关键词或扩大半径
"""
class MeetSpotAgent(BaseAgent):
"""MeetSpot 智能会面推荐 Agent
基于 ReAct 模式的智能代理,通过 think() -> act() 循环完成推荐任务。
"""
name: str = "MeetSpotAgent"
description: str = "智能会面地点推荐助手"
system_prompt: str = SYSTEM_PROMPT
next_step_prompt: str = "请继续执行下一步,或者如果已完成所有工具调用,请生成最终推荐结果。"
max_steps: int = 15 # 允许更多步骤以完成复杂任务
# 工具集合
available_tools: ToolCollection = Field(default=None)
# 当前工具调用
tool_calls: List[Any] = Field(default_factory=list)
# 存储中间结果
geocode_results: List[Dict] = Field(default_factory=list)
center_point: Optional[Dict] = None
search_results: List[Dict] = Field(default_factory=list)
class Config:
arbitrary_types_allowed = True
extra = "allow"
def __init__(self, **data):
super().__init__(**data)
# 初始化工具集合
if self.available_tools is None:
self.available_tools = ToolCollection(
GeocodeTool(),
CalculateCenterTool(),
SearchPOITool(),
GenerateRecommendationTool()
)
async def step(self) -> str:
"""执行一步: think + act
Returns:
步骤执行结果的描述
"""
# Think: 决定下一步行动
should_continue = await self.think()
if not should_continue:
self.state = AgentState.FINISHED
return "任务完成"
# Act: 执行工具调用
result = await self.act()
return result
async def think(self) -> bool:
"""思考阶段 - 决定下一步行动
使用 LLM 分析当前状态,决定是否需要调用工具以及调用哪个工具。
Returns:
是否需要继续执行
"""
# 构建消息
messages = self.memory.messages.copy()
# 添加提示引导下一步
if self.next_step_prompt and self.current_step > 1:
messages.append(Message.user_message(self.next_step_prompt))
# 调用 LLM 获取响应
response = await self.llm.ask_tool(
messages=messages,
system_msgs=[Message.system_message(self.system_prompt)],
tools=self.available_tools.to_params(),
tool_choice="auto"
)
if response is None:
logger.warning("LLM 返回空响应")
return False
# 提取工具调用和内容
self.tool_calls = response.tool_calls or []
content = response.content or ""
logger.info(f"Agent 思考: {content[:200]}..." if len(content) > 200 else f"Agent 思考: {content}")
if self.tool_calls:
tool_names = [tc.function.name for tc in self.tool_calls]
logger.info(f"选择工具: {tool_names}")
# 保存 assistant 消息到记忆
if self.tool_calls:
# 带工具调用的消息
tool_calls_data = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
}
for tc in self.tool_calls
]
self.memory.add_message(Message(
role="assistant",
content=content,
tool_calls=tool_calls_data
))
elif content:
# 纯文本消息(可能是最终回复)
self.memory.add_message(Message.assistant_message(content))
# 如果没有工具调用且有内容,可能是最终回复
if "推荐" in content and len(content) > 100:
return False # 结束循环
return bool(self.tool_calls) or bool(content)
async def act(self) -> str:
"""行动阶段 - 执行工具调用
执行思考阶段决定的工具调用,并将结果添加到记忆。
Returns:
工具执行结果的描述
"""
if not self.tool_calls:
# 没有工具调用,返回最后一条消息的内容
return self.memory.messages[-1].content or "无操作"
results = []
for call in self.tool_calls:
tool_name = call.function.name
tool_args = call.function.arguments
try:
# 解析参数
args = json.loads(tool_args) if isinstance(tool_args, str) else tool_args
# 执行工具
logger.info(f"执行工具: {tool_name}, 参数: {args}")
result = await self.available_tools.execute(name=tool_name, tool_input=args)
# 保存中间结果
self._save_intermediate_result(tool_name, result, args)
# 将工具结果添加到记忆
result_str = str(result)
self.memory.add_message(Message.tool_message(
content=result_str,
tool_call_id=call.id,
name=tool_name
))
logger.info(f"工具 {tool_name} 完成")
results.append(f"{tool_name}: 成功")
except Exception as e:
error_msg = f"工具执行失败: {str(e)}"
logger.error(f"{tool_name} {error_msg}")
# 添加错误消息到记忆
self.memory.add_message(Message.tool_message(
content=error_msg,
tool_call_id=call.id,
name=tool_name
))
results.append(f"{tool_name}: 失败 - {str(e)}")
return " | ".join(results)
def _save_intermediate_result(self, tool_name: str, result: Any, args: Dict) -> None:
"""保存工具执行的中间结果
Args:
tool_name: 工具名称
result: 工具执行结果
args: 工具参数
"""
try:
# 解析结果
if hasattr(result, 'output') and result.output:
data = json.loads(result.output) if isinstance(result.output, str) else result.output
else:
return
if tool_name == "geocode" and data:
self.geocode_results.append({
"address": args.get("address", ""),
"lng": data.get("lng"),
"lat": data.get("lat"),
"formatted_address": data.get("formatted_address", "")
})
elif tool_name == "calculate_center" and data:
self.center_point = data.get("center")
elif tool_name == "search_poi" and data:
places = data.get("places", [])
self.search_results.extend(places)
except Exception as e:
logger.debug(f"保存中间结果时出错: {e}")
async def recommend(
self,
locations: List[str],
keywords: str = "咖啡馆",
requirements: str = ""
) -> Dict:
"""执行推荐任务
这是 Agent 的主要入口方法,接收用户输入并返回推荐结果。
Args:
locations: 参与者位置列表
keywords: 搜索关键词(场所类型)
requirements: 用户特殊需求
Returns:
包含推荐结果的字典
"""
# 重置状态
self.geocode_results = []
self.center_point = None
self.search_results = []
self.current_step = 0
self.state = AgentState.IDLE
self.memory.clear()
# 构建任务描述
locations_str = "".join(locations)
task = f"""请帮我找到适合会面的地点:
**参与者位置**{locations_str}
**想找的场所类型**{keywords}
**特殊需求**{requirements or "无特殊需求"}
请按照工作流程执行:
1. 先用 geocode 工具获取每个位置的坐标
2. 用 calculate_center 计算中心点
3. 用 search_poi 搜索附近的 {keywords}
4. 用 generate_recommendation 生成推荐
最后请用中文总结推荐结果。"""
# 执行任务
result = await self.run(task)
# 格式化返回结果
return self._format_result(result)
def _format_result(self, raw_result: str) -> Dict:
"""格式化最终结果
Args:
raw_result: Agent 执行的原始结果
Returns:
格式化的结果字典
"""
# 获取最后一条 assistant 消息作为最终推荐
final_recommendation = ""
for msg in reversed(self.memory.messages):
if msg.role == "assistant" and msg.content:
final_recommendation = msg.content
break
return {
"success": self.state == AgentState.IDLE, # IDLE 表示正常完成
"recommendation": final_recommendation,
"geocode_results": self.geocode_results,
"center_point": self.center_point,
"search_results": self.search_results[:10], # 限制返回数量
"steps_executed": self.current_step,
"raw_output": raw_result
}
# 创建默认 Agent 实例的工厂函数
def create_meetspot_agent() -> MeetSpotAgent:
"""创建 MeetSpotAgent 实例
Returns:
配置好的 MeetSpotAgent 实例
"""
return MeetSpotAgent()

514
MeetSpot/app/agent/tools.py Normal file
View File

@@ -0,0 +1,514 @@
"""MeetSpot Agent 工具集 - 封装推荐系统的核心功能"""
import json
from typing import Any, Dict, List, Optional
from pydantic import Field
from app.tool.base import BaseTool, ToolResult
from app.logger import logger
class GeocodeTool(BaseTool):
"""地理编码工具 - 将地址转换为经纬度坐标"""
name: str = "geocode"
description: str = """将地址或地点名称转换为经纬度坐标。
支持各种地址格式:
- 完整地址:'北京市海淀区中关村大街1号'
- 大学简称:'北大''清华''复旦'(自动扩展为完整地址)
- 知名地标:'天安门''外滩''广州塔'
- 商圈区域:'三里屯''王府井'
返回地址的经纬度坐标和格式化地址。"""
parameters: dict = {
"type": "object",
"properties": {
"address": {
"type": "string",
"description": "地址或地点名称,如'北京大学''上海市浦东新区陆家嘴'"
}
},
"required": ["address"]
}
class Config:
arbitrary_types_allowed = True
def _get_recommender(self):
"""延迟加载推荐器,并确保 API key 已设置"""
if not hasattr(self, '_cached_recommender'):
from app.tool.meetspot_recommender import CafeRecommender
from app.config import config
recommender = CafeRecommender()
# 确保 API key 已设置
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
recommender.api_key = config.amap.api_key
object.__setattr__(self, '_cached_recommender', recommender)
return self._cached_recommender
async def execute(self, address: str) -> ToolResult:
"""执行地理编码"""
try:
recommender = self._get_recommender()
result = await recommender._geocode(address)
if result:
location = result.get("location", "")
lng, lat = location.split(",") if location else (None, None)
return BaseTool.success_response({
"address": address,
"formatted_address": result.get("formatted_address", ""),
"location": location,
"lng": float(lng) if lng else None,
"lat": float(lat) if lat else None,
"city": result.get("city", ""),
"district": result.get("district", "")
})
return BaseTool.fail_response(f"无法解析地址: {address}")
except Exception as e:
logger.error(f"地理编码失败: {e}")
return BaseTool.fail_response(f"地理编码错误: {str(e)}")
class CalculateCenterTool(BaseTool):
"""智能中心点工具 - 计算多个位置的最佳会面点
使用智能算法,综合考虑:
- POI 密度:周边是否有足够的目标场所
- 交通便利性:是否靠近地铁站/公交站
- 公平性:对所有参与者的距离是否均衡
"""
name: str = "calculate_center"
description: str = """智能计算最佳会面中心点。
不同于简单的几何中心,本工具会:
1. 在几何中心周围生成多个候选点
2. 评估每个候选点的 POI 密度、交通便利性和公平性
3. 返回综合得分最高的点作为最佳会面位置
这样可以避免中心点落在河流、荒地等不适合的位置。"""
parameters: dict = {
"type": "object",
"properties": {
"coordinates": {
"type": "array",
"description": "坐标点列表,每个元素包含 lng经度、lat纬度和可选的 name名称",
"items": {
"type": "object",
"properties": {
"lng": {"type": "number", "description": "经度"},
"lat": {"type": "number", "description": "纬度"},
"name": {"type": "string", "description": "位置名称(可选)"}
},
"required": ["lng", "lat"]
}
},
"keywords": {
"type": "string",
"description": "搜索的场所类型,如'咖啡馆''餐厅',用于评估 POI 密度",
"default": "咖啡馆"
},
"use_smart_algorithm": {
"type": "boolean",
"description": "是否使用智能算法(考虑 POI 密度和交通),默认 true",
"default": True
}
},
"required": ["coordinates"]
}
class Config:
arbitrary_types_allowed = True
def _get_recommender(self):
"""延迟加载推荐器,并确保 API key 已设置"""
if not hasattr(self, '_cached_recommender'):
from app.tool.meetspot_recommender import CafeRecommender
from app.config import config
recommender = CafeRecommender()
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
recommender.api_key = config.amap.api_key
object.__setattr__(self, '_cached_recommender', recommender)
return self._cached_recommender
async def execute(
self,
coordinates: List[Dict],
keywords: str = "咖啡馆",
use_smart_algorithm: bool = True
) -> ToolResult:
"""计算最佳中心点"""
try:
if not coordinates or len(coordinates) < 2:
return BaseTool.fail_response("至少需要2个坐标点来计算中心")
recommender = self._get_recommender()
# 转换为 (lng, lat) 元组列表
coord_tuples = [(c["lng"], c["lat"]) for c in coordinates]
if use_smart_algorithm:
# 使用智能中心点算法
center, evaluation_details = await recommender._calculate_smart_center(
coord_tuples, keywords
)
logger.info(f"智能中心点算法完成,最优中心: {center}")
else:
# 使用简单几何中心
center = recommender._calculate_center_point(coord_tuples)
evaluation_details = {"algorithm": "geometric_center"}
# 计算每个点到中心的距离
distances = []
for c in coordinates:
dist = recommender._calculate_distance(center, (c["lng"], c["lat"]))
distances.append({
"name": c.get("name", f"({c['lng']:.4f}, {c['lat']:.4f})"),
"distance_to_center": round(dist, 0)
})
max_dist = max(d["distance_to_center"] for d in distances)
min_dist = min(d["distance_to_center"] for d in distances)
result = {
"center": {
"lng": round(center[0], 6),
"lat": round(center[1], 6)
},
"algorithm": "smart" if use_smart_algorithm else "geometric",
"input_count": len(coordinates),
"distances": distances,
"max_distance": max_dist,
"fairness_score": round(100 - (max_dist - min_dist) / 100, 1)
}
# 添加智能算法的评估详情
if use_smart_algorithm and evaluation_details:
result["evaluation"] = {
"geo_center": evaluation_details.get("geo_center"),
"best_score": evaluation_details.get("best_score"),
"top_candidates": len(evaluation_details.get("all_candidates", []))
}
return BaseTool.success_response(result)
except Exception as e:
logger.error(f"计算中心点失败: {e}")
return BaseTool.fail_response(f"计算中心点错误: {str(e)}")
class SearchPOITool(BaseTool):
"""搜索POI工具 - 在指定位置周围搜索场所"""
name: str = "search_poi"
description: str = """在指定中心点周围搜索各类场所POI
支持搜索咖啡馆、餐厅、图书馆、健身房、KTV、电影院、商场等。
返回场所的名称、地址、评分、距离等信息。"""
parameters: dict = {
"type": "object",
"properties": {
"center_lng": {
"type": "number",
"description": "中心点经度"
},
"center_lat": {
"type": "number",
"description": "中心点纬度"
},
"keywords": {
"type": "string",
"description": "搜索关键词,如'咖啡馆''餐厅''图书馆'"
},
"radius": {
"type": "integer",
"description": "搜索半径默认3000米",
"default": 3000
}
},
"required": ["center_lng", "center_lat", "keywords"]
}
class Config:
arbitrary_types_allowed = True
def _get_recommender(self):
"""延迟加载推荐器,并确保 API key 已设置"""
if not hasattr(self, '_cached_recommender'):
from app.tool.meetspot_recommender import CafeRecommender
from app.config import config
recommender = CafeRecommender()
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
recommender.api_key = config.amap.api_key
object.__setattr__(self, '_cached_recommender', recommender)
return self._cached_recommender
async def execute(
self,
center_lng: float,
center_lat: float,
keywords: str,
radius: int = 3000
) -> ToolResult:
"""搜索POI"""
try:
recommender = self._get_recommender()
center = f"{center_lng},{center_lat}"
places = await recommender._search_pois(
location=center,
keywords=keywords,
radius=radius,
types="",
offset=20
)
if not places:
return BaseTool.fail_response(
f"在 ({center_lng:.4f}, {center_lat:.4f}) 附近 {radius}米范围内"
f"未找到与 '{keywords}' 相关的场所"
)
# 简化返回数据
simplified = []
for p in places[:15]: # 最多返回15个
biz_ext = p.get("biz_ext", {}) or {}
location = p.get("location", "")
lng, lat = location.split(",") if location else (0, 0)
# 计算到中心的距离
distance = recommender._calculate_distance(
(center_lng, center_lat),
(float(lng), float(lat))
) if location else 0
simplified.append({
"name": p.get("name", ""),
"address": p.get("address", ""),
"rating": biz_ext.get("rating", "N/A"),
"cost": biz_ext.get("cost", ""),
"location": location,
"lng": float(lng) if lng else None,
"lat": float(lat) if lat else None,
"distance": round(distance, 0),
"tel": p.get("tel", ""),
"tag": p.get("tag", ""),
"type": p.get("type", "")
})
# 按距离排序
simplified.sort(key=lambda x: x.get("distance", 9999))
return BaseTool.success_response({
"places": simplified,
"count": len(simplified),
"keywords": keywords,
"center": {"lng": center_lng, "lat": center_lat},
"radius": radius
})
except Exception as e:
logger.error(f"POI搜索失败: {e}")
return BaseTool.fail_response(f"POI搜索错误: {str(e)}")
class GenerateRecommendationTool(BaseTool):
"""智能推荐工具 - 使用 LLM 生成个性化推荐结果
结合规则评分和 LLM 智能评分,生成更精准的推荐:
- 规则评分:基于距离、评分、热度等客观指标
- LLM 评分:理解用户需求语义,评估场所匹配度
"""
name: str = "generate_recommendation"
description: str = """智能生成会面地点推荐。
本工具使用双层评分系统:
1. 规则评分40%):基于距离、评分、热度等客观指标
2. LLM 智能评分60%):理解用户需求,评估场所特色与需求的匹配度
最终生成个性化的推荐理由,帮助用户做出最佳选择。"""
parameters: dict = {
"type": "object",
"properties": {
"places": {
"type": "array",
"description": "候选场所列表来自search_poi的结果",
"items": {
"type": "object",
"properties": {
"name": {"type": "string", "description": "场所名称"},
"address": {"type": "string", "description": "地址"},
"rating": {"type": "string", "description": "评分"},
"distance": {"type": "number", "description": "距中心点距离"},
"location": {"type": "string", "description": "坐标"}
}
}
},
"center": {
"type": "object",
"description": "中心点坐标",
"properties": {
"lng": {"type": "number", "description": "经度"},
"lat": {"type": "number", "description": "纬度"}
},
"required": ["lng", "lat"]
},
"participant_locations": {
"type": "array",
"description": "参与者位置名称列表,用于 LLM 评估公平性",
"items": {"type": "string"},
"default": []
},
"keywords": {
"type": "string",
"description": "搜索的场所类型,如'咖啡馆''餐厅'",
"default": "咖啡馆"
},
"user_requirements": {
"type": "string",
"description": "用户的特殊需求,如'停车方便''环境安静'",
"default": ""
},
"recommendation_count": {
"type": "integer",
"description": "推荐数量默认5个",
"default": 5
},
"use_llm_ranking": {
"type": "boolean",
"description": "是否使用 LLM 智能排序,默认 true",
"default": True
}
},
"required": ["places", "center"]
}
class Config:
arbitrary_types_allowed = True
def _get_recommender(self):
"""延迟加载推荐器,并确保 API key 已设置"""
if not hasattr(self, '_cached_recommender'):
from app.tool.meetspot_recommender import CafeRecommender
from app.config import config
recommender = CafeRecommender()
if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
recommender.api_key = config.amap.api_key
object.__setattr__(self, '_cached_recommender', recommender)
return self._cached_recommender
async def execute(
self,
places: List[Dict],
center: Dict,
participant_locations: List[str] = None,
keywords: str = "咖啡馆",
user_requirements: str = "",
recommendation_count: int = 5,
use_llm_ranking: bool = True
) -> ToolResult:
"""智能生成推荐"""
try:
if not places:
return BaseTool.fail_response("没有候选场所可供推荐")
recommender = self._get_recommender()
center_point = (center["lng"], center["lat"])
# 1. 先用规则评分进行初步排序
ranked = recommender._rank_places(
places=places,
center_point=center_point,
user_requirements=user_requirements,
keywords=keywords
)
# 2. 如果启用 LLM 智能排序,进行重排序
if use_llm_ranking and participant_locations:
logger.info("启用 LLM 智能排序")
ranked = await recommender._llm_smart_ranking(
places=ranked,
user_requirements=user_requirements,
participant_locations=participant_locations or [],
keywords=keywords,
top_n=recommendation_count + 3 # 多取几个以便筛选
)
# 取前N个推荐
top_places = ranked[:recommendation_count]
# 生成推荐结果
recommendations = []
for i, place in enumerate(top_places, 1):
score = place.get("_final_score") or place.get("_score", 0)
distance = place.get("_distance") or place.get("distance", 0)
rating = place.get("_raw_rating") or place.get("rating", "N/A")
# 优先使用 LLM 生成的理由
llm_reason = place.get("_llm_reason", "")
rule_reason = place.get("_recommendation_reason", "")
if llm_reason:
reasons = [llm_reason]
elif rule_reason:
reasons = [rule_reason]
else:
# 兜底:构建基础推荐理由
reasons = []
if distance <= 500:
reasons.append("距离中心点很近")
elif distance <= 1000:
reasons.append("距离适中")
if rating != "N/A":
try:
r = float(rating)
if r >= 4.5:
reasons.append("口碑优秀")
elif r >= 4.0:
reasons.append("评价良好")
except (ValueError, TypeError):
pass
if not reasons:
reasons = ["综合评分较高"]
recommendations.append({
"rank": i,
"name": place.get("name", ""),
"address": place.get("address", ""),
"rating": str(rating) if rating else "N/A",
"distance": round(distance, 0),
"score": round(score, 1),
"llm_score": place.get("_llm_score", 0),
"tel": place.get("tel", ""),
"reasons": reasons,
"location": place.get("location", ""),
"scoring_method": "llm+rule" if place.get("_llm_score") else "rule"
})
return BaseTool.success_response({
"recommendations": recommendations,
"total_candidates": len(places),
"user_requirements": user_requirements,
"center": center,
"llm_ranking_used": use_llm_ranking and bool(participant_locations)
})
except Exception as e:
logger.error(f"生成推荐失败: {e}")
return BaseTool.fail_response(f"生成推荐错误: {str(e)}")
# 导出所有工具
__all__ = [
"GeocodeTool",
"CalculateCenterTool",
"SearchPOITool",
"GenerateRecommendationTool"
]

View File

@@ -0,0 +1,2 @@
"""认证相关模块。"""

58
MeetSpot/app/auth/jwt.py Normal file
View File

@@ -0,0 +1,58 @@
"""JWT 工具函数与统一用户依赖。"""
import os
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.crud import get_user_by_id
from app.db.database import get_db
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "meetspot-dev-secret")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = int(os.getenv("ACCESS_TOKEN_EXPIRE_DAYS", "7"))
bearer_scheme = HTTPBearer(auto_error=False)
def create_access_token(data: dict) -> str:
"""生成带过期时间的JWT。"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> Optional[dict]:
"""解码并验证JWT。"""
try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
return None
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
db: AsyncSession = Depends(get_db),
):
"""FastAPI 依赖:获取当前用户。"""
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少认证信息"
)
payload = decode_token(credentials.credentials)
if not payload or "sub" not in payload:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌无效")
user = await get_user_by_id(db, payload["sub"])
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
return user

24
MeetSpot/app/auth/sms.py Normal file
View File

@@ -0,0 +1,24 @@
"""短信验证码(Mock版)。"""
from typing import Dict
MOCK_CODE = "123456"
_code_store: Dict[str, str] = {}
async def send_login_code(phone: str) -> str:
"""Mock发送验证码固定返回`123456`。
- 真实环境可替换为短信网关调用
- 这里简单记忆最后一次下发的验证码,便于后续校验扩展
"""
_code_store[phone] = MOCK_CODE
return MOCK_CODE
def validate_code(phone: str, code: str) -> bool:
"""校验验证码MVP阶段固定匹配Mock值。"""
return code == MOCK_CODE

315
MeetSpot/app/config.py Normal file
View File

@@ -0,0 +1,315 @@
import threading
import tomllib
from pathlib import Path
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
def get_project_root() -> Path:
"""Get the project root directory"""
return Path(__file__).resolve().parent.parent
PROJECT_ROOT = get_project_root()
WORKSPACE_ROOT = PROJECT_ROOT / "workspace"
class LLMSettings(BaseModel):
model: str = Field(..., description="Model name")
base_url: str = Field(..., description="API base URL")
api_key: str = Field(..., description="API key")
max_tokens: int = Field(4096, description="Maximum number of tokens per request")
max_input_tokens: Optional[int] = Field(
None,
description="Maximum input tokens to use across all requests (None for unlimited)",
)
temperature: float = Field(1.0, description="Sampling temperature")
api_type: str = Field(..., description="Azure, Openai, or Ollama")
api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
class ProxySettings(BaseModel):
server: str = Field(None, description="Proxy server address")
username: Optional[str] = Field(None, description="Proxy username")
password: Optional[str] = Field(None, description="Proxy password")
class SearchSettings(BaseModel):
engine: str = Field(default="Google", description="Search engine the llm to use")
class AMapSettings(BaseModel):
"""高德地图API配置"""
api_key: str = Field(..., description="高德地图API密钥")
web_api_key: Optional[str] = Field(None, description="高德地图JavaScript API密钥")
class BrowserSettings(BaseModel):
headless: bool = Field(False, description="Whether to run browser in headless mode")
disable_security: bool = Field(
True, description="Disable browser security features"
)
extra_chromium_args: List[str] = Field(
default_factory=list, description="Extra arguments to pass to the browser"
)
chrome_instance_path: Optional[str] = Field(
None, description="Path to a Chrome instance to use"
)
wss_url: Optional[str] = Field(
None, description="Connect to a browser instance via WebSocket"
)
cdp_url: Optional[str] = Field(
None, description="Connect to a browser instance via CDP"
)
proxy: Optional[ProxySettings] = Field(
None, description="Proxy settings for the browser"
)
max_content_length: int = Field(
2000, description="Maximum length for content retrieval operations"
)
class SandboxSettings(BaseModel):
"""Configuration for the execution sandbox"""
use_sandbox: bool = Field(False, description="Whether to use the sandbox")
image: str = Field("python:3.12-slim", description="Base image")
work_dir: str = Field("/workspace", description="Container working directory")
memory_limit: str = Field("512m", description="Memory limit")
cpu_limit: float = Field(1.0, description="CPU limit")
timeout: int = Field(300, description="Default command timeout (seconds)")
network_enabled: bool = Field(
False, description="Whether network access is allowed"
)
class AppConfig(BaseModel):
llm: Dict[str, LLMSettings]
sandbox: Optional[SandboxSettings] = Field(
None, description="Sandbox configuration"
)
browser_config: Optional[BrowserSettings] = Field(
None, description="Browser configuration"
)
search_config: Optional[SearchSettings] = Field(
None, description="Search configuration"
)
amap: Optional[AMapSettings] = Field(
None, description="高德地图API配置"
)
class Config:
arbitrary_types_allowed = True
class Config:
_instance = None
_lock = threading.Lock()
_initialized = False
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
with self._lock:
if not self._initialized:
self._config = None
self._load_initial_config()
self._initialized = True
@staticmethod
def _get_config_path() -> Path:
root = PROJECT_ROOT
config_path = root / "config" / "config.toml"
if config_path.exists():
return config_path
example_path = root / "config" / "config.toml.example"
if example_path.exists():
return example_path
# 如果都没有找到,返回默认路径,让后续创建默认配置
return config_path
def _load_config(self) -> dict:
try:
config_path = self._get_config_path()
if not config_path.exists():
# 创建默认配置
default_config = {
"llm": {
"model": "gpt-3.5-turbo",
"api_key": "",
"base_url": "",
"max_tokens": 4096,
"temperature": 1.0,
"api_type": "",
"api_version": ""
},
"amap": {
"api_key": "",
"security_js_code": ""
},
"log": {
"level": "info",
"file": "logs/meetspot.log"
},
"server": {
"host": "0.0.0.0",
"port": 8000
}
}
return default_config
with config_path.open("rb") as f:
return tomllib.load(f)
except Exception as e:
# 如果加载失败,返回默认配置
print(f"Failed to load config file, using defaults: {e}")
return {
"llm": {
"model": "gpt-3.5-turbo",
"api_key": "",
"base_url": "",
"max_tokens": 4096,
"temperature": 1.0,
"api_type": "",
"api_version": ""
}
}
def _load_initial_config(self):
raw_config = self._load_config()
base_llm = raw_config.get("llm", {})
# 从环境变量读取敏感信息
import os
openai_api_key = os.getenv("OPENAI_API_KEY", "") or os.getenv("LLM_API_KEY", "")
amap_api_key = os.getenv("AMAP_API_KEY", "")
# 支持 Render 部署的环境变量配置
llm_base_url = os.getenv("LLM_API_BASE", "") or base_llm.get("base_url", "")
llm_model = os.getenv("LLM_MODEL", "") or base_llm.get("model", "gpt-3.5-turbo")
llm_overrides = {
k: v for k, v in raw_config.get("llm", {}).items() if isinstance(v, dict)
}
default_settings = {
"model": llm_model, # 优先使用环境变量
"base_url": llm_base_url, # 优先使用环境变量
"api_key": openai_api_key, # 从环境变量获取
"max_tokens": base_llm.get("max_tokens", 4096),
"max_input_tokens": base_llm.get("max_input_tokens"),
"temperature": base_llm.get("temperature", 1.0),
"api_type": base_llm.get("api_type", ""),
"api_version": base_llm.get("api_version", ""),
}
# handle browser config.
browser_config = raw_config.get("browser", {})
browser_settings = None
if browser_config:
# handle proxy settings.
proxy_config = browser_config.get("proxy", {})
proxy_settings = None
if proxy_config and proxy_config.get("server"):
proxy_settings = ProxySettings(
**{
k: v
for k, v in proxy_config.items()
if k in ["server", "username", "password"] and v
}
)
# filter valid browser config parameters.
valid_browser_params = {
k: v
for k, v in browser_config.items()
if k in BrowserSettings.__annotations__ and v is not None
}
# if there is proxy settings, add it to the parameters.
if proxy_settings:
valid_browser_params["proxy"] = proxy_settings
# only create BrowserSettings when there are valid parameters.
if valid_browser_params:
browser_settings = BrowserSettings(**valid_browser_params)
search_config = raw_config.get("search", {})
search_settings = None
if search_config:
search_settings = SearchSettings(**search_config)
sandbox_config = raw_config.get("sandbox", {})
if sandbox_config:
sandbox_settings = SandboxSettings(**sandbox_config)
else:
sandbox_settings = SandboxSettings()
# 处理高德地图API配置
amap_config = raw_config.get("amap", {})
amap_settings = None
# 优先使用环境变量中的 AMAP_API_KEY
if amap_api_key:
amap_settings = AMapSettings(
api_key=amap_api_key,
security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", amap_config.get("security_js_code", ""))
)
elif amap_config and amap_config.get("api_key"):
amap_settings = AMapSettings(**amap_config)
config_dict = {
"llm": {
"default": default_settings,
**{
name: {**default_settings, **override_config}
for name, override_config in llm_overrides.items()
},
},
"sandbox": sandbox_settings,
"browser_config": browser_settings,
"search_config": search_settings,
"amap": amap_settings,
}
self._config = AppConfig(**config_dict)
@property
def llm(self) -> Dict[str, LLMSettings]:
return self._config.llm
@property
def sandbox(self) -> SandboxSettings:
return self._config.sandbox
@property
def browser_config(self) -> Optional[BrowserSettings]:
return self._config.browser_config
@property
def search_config(self) -> Optional[SearchSettings]:
return self._config.search_config
@property
def amap(self) -> Optional[AMapSettings]:
"""获取高德地图API配置"""
return self._config.amap
@property
def workspace_root(self) -> Path:
"""Get the workspace root directory"""
return WORKSPACE_ROOT
@property
def root_path(self) -> Path:
"""Get the root path of the application"""
return PROJECT_ROOT
config = Config()

View File

@@ -0,0 +1,120 @@
import os
import threading
import tomllib
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
def get_project_root() -> Path:
"""Get the project root directory"""
return Path(__file__).resolve().parent.parent
PROJECT_ROOT = get_project_root()
WORKSPACE_ROOT = PROJECT_ROOT / "workspace"
class AMapSettings(BaseModel):
"""高德地图API配置"""
api_key: str = Field(..., description="高德地图API密钥")
security_js_code: Optional[str] = Field(None, description="高德地图JavaScript API安全密钥")
class LogSettings(BaseModel):
"""日志配置"""
level: str = Field(default="INFO", description="日志级别")
file_path: str = Field(default="logs/meetspot.log", description="日志文件路径")
class AppConfig(BaseModel):
"""应用配置"""
amap: AMapSettings = Field(..., description="高德地图API配置")
log: Optional[LogSettings] = Field(default=LogSettings(), description="日志配置")
class Config:
arbitrary_types_allowed = True
class Config:
"""配置管理器(单例模式)"""
_instance = None
_lock = threading.Lock()
_initialized = False
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
with self._lock:
if not self._initialized:
self._config = None
self._load_initial_config()
self._initialized = True
def _load_initial_config(self):
"""加载初始配置"""
try:
# 首先尝试从环境变量加载Vercel 部署)
if os.getenv("AMAP_API_KEY"):
self._config = AppConfig(
amap=AMapSettings(
api_key=os.getenv("AMAP_API_KEY", ""),
security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", "")
)
)
return
# 然后尝试从配置文件加载(本地开发)
config_path = PROJECT_ROOT / "config" / "config.toml"
if config_path.exists():
with open(config_path, "rb") as f:
toml_data = tomllib.load(f)
amap_config = toml_data.get("amap", {})
if not amap_config.get("api_key"):
raise ValueError("高德地图API密钥未配置")
self._config = AppConfig(
amap=AMapSettings(**amap_config),
log=LogSettings(**toml_data.get("log", {}))
)
else:
raise FileNotFoundError(f"配置文件不存在: {config_path}")
except Exception as e:
# 提供默认配置以防止启动失败
print(f"配置加载失败,使用默认配置: {e}")
self._config = AppConfig(
amap=AMapSettings(
api_key=os.getenv("AMAP_API_KEY", ""),
security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", "")
)
)
def reload(self):
"""重新加载配置"""
with self._lock:
self._initialized = False
self._load_initial_config()
self._initialized = True
@property
def amap(self) -> AMapSettings:
"""获取高德地图配置"""
return self._config.amap
@property
def log(self) -> LogSettings:
"""获取日志配置"""
return self._config.log
# 全局配置实例
config = Config()

View File

@@ -0,0 +1,2 @@
"""数据库相关模块初始化。"""

50
MeetSpot/app/db/crud.py Normal file
View File

@@ -0,0 +1,50 @@
"""常用数据库操作封装。"""
from datetime import datetime
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user import User
def _default_nickname(phone: str) -> str:
suffix = phone[-4:] if len(phone) >= 4 else phone
return f"用户{suffix}"
async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]:
"""根据手机号查询用户。"""
stmt = select(User).where(User.phone == phone)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]:
"""根据ID查询用户。"""
stmt = select(User).where(User.id == user_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def create_user(
db: AsyncSession, phone: str, nickname: Optional[str] = None, avatar_url: str = ""
) -> User:
"""创建新用户。"""
user = User(
phone=phone,
nickname=nickname or _default_nickname(phone),
avatar_url=avatar_url or "",
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def touch_last_login(db: AsyncSession, user: User) -> None:
"""更新用户最近登录时间。"""
user.last_login = datetime.utcnow()
await db.commit()

View File

@@ -0,0 +1,48 @@
"""数据库引擎与会话管理。
使用SQLite作为MVP默认存储保留通过环境变量`DATABASE_URL`切换到PostgreSQL的能力。
"""
import os
from pathlib import Path
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base
# 项目根目录默认将SQLite数据库放在data目录下
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR.mkdir(exist_ok=True)
# 允许通过环境变量覆盖数据库连接串
DEFAULT_SQLITE_PATH = DATA_DIR / "meetspot.db"
DATABASE_URL = os.getenv(
"DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_SQLITE_PATH.as_posix()}"
)
# 创建异步引擎与会话工厂
engine = create_async_engine(DATABASE_URL, echo=False, future=True)
AsyncSessionLocal = async_sessionmaker(
bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False
)
# 统一的ORM基类
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI 依赖:提供数据库会话并确保正确关闭。"""
async with AsyncSessionLocal() as session:
yield session
async def init_db() -> None:
"""在启动时创建数据库表。"""
# 延迟导入以避免循环依赖
from app import models # noqa: F401 确保所有模型已注册
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

View File

@@ -0,0 +1,652 @@
"""
MeetSpot Design Tokens - 单一真相来源
所有色彩、间距、字体系统的中心定义文件。
修改本文件会影响:
1. 基础模板 (templates/base.html)
2. 静态HTML (public/*.html)
3. 动态生成页面 (workspace/js_src/*.html)
WCAG 2.1 AA级对比度标准:
- 正文: ≥ 4.5:1
- 大文字: ≥ 3.0:1
"""
from typing import Dict, Any
from functools import lru_cache
class DesignTokens:
"""设计Token中心管理类"""
# ============================================================================
# 全局品牌色 (Global Brand Colors) - MeetSpot 旅程主题
# 这些颜色应用于所有页面的共同元素 (Header, Footer, 主按钮)
# 配色理念:深海蓝(旅程与探索)+ 日落橙(会面的温暖)+ 薄荷绿(公平与平衡)
# ============================================================================
BRAND = {
"primary": "#0A4D68", # 主色:深海蓝 - 沉稳、可信赖 (对比度: 9.12:1 ✓)
"primary_dark": "#05445E", # 暗深海蓝 - 悬停态 (对比度: 11.83:1 ✓)
"primary_light": "#088395", # 亮海蓝 - 装饰性元素 (对比度: 5.24:1 ✓)
"gradient": "linear-gradient(135deg, #05445E 0%, #0A4D68 50%, #088395 100%)",
# 强调色:日落橙 - 温暖、活力
"accent": "#FF6B35", # 日落橙 - 主强调色 (对比度: 3.55:1, 大文字用途)
"accent_light": "#FF8C61", # 亮橙 - 次要强调 (对比度: 2.87:1, 装饰用途)
# 次要色:薄荷绿 - 清新、平衡
"secondary": "#06D6A0", # 薄荷绿 (对比度: 2.28:1, 装饰用途)
# 功能色 - 全部WCAG AA级
"success": "#0C8A5D", # 成功绿 - 保持 (4.51:1 ✓)
"info": "#2563EB", # 信息蓝 - 保持 (5.17:1 ✓)
"warning": "#CA7205", # 警告橙 - 保持 (4.50:1 ✓)
"error": "#DC2626", # 错误红 - 保持 (4.83:1 ✓)
}
# ============================================================================
# 文字颜色系统 (Text Colors)
# 基于WCAG 2.1标准,所有文字色在白色背景上对比度 ≥ 4.5:1
# ============================================================================
TEXT = {
"primary": "#111827", # 主文字 (gray-900, 对比度 17.74:1 ✓)
"secondary": "#4B5563", # 次要文字 (gray-600, 对比度 7.56:1 ✓)
"tertiary": "#6B7280", # 三级文字 (gray-500, 对比度 4.83:1 ✓)
"muted": "#6B7280", # 弱化文字 - 修正 (原#9CA3AF: 2.54:1 -> 4.83:1, 使用tertiary色)
"disabled": "#9CA3AF", # 禁用文字 - 保持低对比度 (装饰性文字允许 <3:1)
"inverse": "#FFFFFF", # 反转文字 (深色背景上)
}
# ============================================================================
# 背景颜色系统 (Background Colors)
# ============================================================================
BACKGROUND = {
"primary": "#FFFFFF", # 主背景 (白色)
"secondary": "#F9FAFB", # 次要背景 (gray-50)
"tertiary": "#F3F4F6", # 三级背景 (gray-100)
"elevated": "#FFFFFF", # 卡片/浮动元素背景 (带阴影)
"overlay": "rgba(0, 0, 0, 0.5)", # 蒙层
}
# ============================================================================
# 边框颜色系统 (Border Colors)
# ============================================================================
BORDER = {
"default": "#E5E7EB", # 默认边框 (gray-200)
"medium": "#D1D5DB", # 中等边框 (gray-300)
"strong": "#9CA3AF", # 强边框 (gray-400)
"focus": "#667EEA", # 焦点边框 (主品牌色)
}
# ============================================================================
# 阴影系统 (Shadow System)
# ============================================================================
SHADOW = {
"sm": "0 1px 2px 0 rgba(0, 0, 0, 0.05)",
"md": "0 4px 6px -1px rgba(0, 0, 0, 0.1)",
"lg": "0 10px 15px -3px rgba(0, 0, 0, 0.1)",
"xl": "0 20px 25px -5px rgba(0, 0, 0, 0.1)",
"2xl": "0 25px 50px -12px rgba(0, 0, 0, 0.25)",
}
# ============================================================================
# 场所类型主题系统 (Venue Theme System)
# 14种预设主题,动态注入到生成的推荐页面中
#
# 每个主题包含:
# - theme_primary: 主色调 (Header背景、主按钮)
# - theme_primary_light: 亮色变体 (悬停态、次要元素)
# - theme_primary_dark: 暗色变体 (Active态、强调元素)
# - theme_secondary: 辅助色 (图标、装饰元素)
# - theme_light: 浅背景色 (卡片背景、Section背景)
# - theme_dark: 深文字色 (标题、关键信息)
#
# WCAG验证: 所有theme_primary在白色背景上对比度 ≥ 3.0:1 (大文字)
# 所有theme_dark在theme_light背景上对比度 ≥ 4.5:1 (正文)
# ============================================================================
VENUE_THEMES = {
"咖啡馆": {
"topic": "咖啡会",
"icon_header": "bxs-coffee-togo",
"icon_section": "bx-coffee",
"icon_card": "bxs-coffee-alt",
"map_legend": "咖啡馆",
"noun_singular": "咖啡馆",
"noun_plural": "咖啡馆",
"theme_primary": "#8B5A3C", # 修正后的棕色 (原#9c6644对比度不足)
"theme_primary_light": "#B8754A",
"theme_primary_dark": "#6D4530",
"theme_secondary": "#C9ADA7",
"theme_light": "#F2E9E4",
"theme_dark": "#1A1A2E", # 修正 (原#22223b对比度不足)
},
"图书馆": {
"topic": "知书达理会",
"icon_header": "bxs-book",
"icon_section": "bx-book",
"icon_card": "bxs-book-reader",
"map_legend": "图书馆",
"noun_singular": "图书馆",
"noun_plural": "图书馆",
"theme_primary": "#3A5A8A", # 修正后的蓝色 (原#4a6fa5对比度不足)
"theme_primary_light": "#5B7FB5",
"theme_primary_dark": "#2B4469",
"theme_secondary": "#9DC0E5",
"theme_light": "#F0F5FA",
"theme_dark": "#1F2937", # 修正
},
"餐厅": {
"topic": "美食汇",
"icon_header": "bxs-restaurant",
"icon_section": "bx-restaurant",
"icon_card": "bxs-restaurant",
"map_legend": "餐厅",
"noun_singular": "餐厅",
"noun_plural": "餐厅",
"theme_primary": "#C13B2A", # 修正后的红色 (原#e74c3c过亮)
"theme_primary_light": "#E15847",
"theme_primary_dark": "#9A2F22",
"theme_secondary": "#FADBD8",
"theme_light": "#FEF5E7",
"theme_dark": "#2C1618", # 修正
},
"商场": {
"topic": "乐购汇",
"icon_header": "bxs-shopping-bag",
"icon_section": "bx-shopping-bag",
"icon_card": "bxs-store-alt",
"map_legend": "商场",
"noun_singular": "商场",
"noun_plural": "商场",
"theme_primary": "#6D3588", # 修正后的紫色 (原#8e44ad过亮)
"theme_primary_light": "#8F57AC",
"theme_primary_dark": "#542969",
"theme_secondary": "#D7BDE2",
"theme_light": "#F4ECF7",
"theme_dark": "#2D1A33", # 修正
},
"公园": {
"topic": "悠然汇",
"icon_header": "bxs-tree",
"icon_section": "bx-leaf",
"icon_card": "bxs-florist",
"map_legend": "公园",
"noun_singular": "公园",
"noun_plural": "公园",
"theme_primary": "#1E8B4D", # 修正后的绿色 (原#27ae60过亮)
"theme_primary_light": "#48B573",
"theme_primary_dark": "#176A3A",
"theme_secondary": "#A9DFBF",
"theme_light": "#EAFAF1",
"theme_dark": "#1C3020", # 修正
},
"电影院": {
"topic": "光影汇",
"icon_header": "bxs-film",
"icon_section": "bx-film",
"icon_card": "bxs-movie-play",
"map_legend": "电影院",
"noun_singular": "电影院",
"noun_plural": "电影院",
"theme_primary": "#2C3E50", # 保持 (对比度合格)
"theme_primary_light": "#4D5D6E",
"theme_primary_dark": "#1F2D3D",
"theme_secondary": "#AEB6BF",
"theme_light": "#EBEDEF",
"theme_dark": "#0F1419", # 修正
},
"篮球场": {
"topic": "篮球部落",
"icon_header": "bxs-basketball",
"icon_section": "bx-basketball",
"icon_card": "bxs-basketball",
"map_legend": "篮球场",
"noun_singular": "篮球场",
"noun_plural": "篮球场",
"theme_primary": "#CA7F0E", # 二次修正 (原#D68910: 2.82:1 -> 3.06:1 for large text)
"theme_primary_light": "#E89618",
"theme_primary_dark": "#A3670B",
"theme_secondary": "#FDEBD0",
"theme_light": "#FEF9E7",
"theme_dark": "#3A2303", # 已修正 ✓
},
"健身房": {
"topic": "健身汇",
"icon_header": "bx-dumbbell",
"icon_section": "bx-dumbbell",
"icon_card": "bx-dumbbell",
"map_legend": "健身房",
"noun_singular": "健身房",
"noun_plural": "健身房",
"theme_primary": "#C5671A", # 修正后的橙色 (原#e67e22过亮)
"theme_primary_light": "#E17E2E",
"theme_primary_dark": "#9E5315",
"theme_secondary": "#FDEBD0",
"theme_light": "#FEF9E7",
"theme_dark": "#3A2303", # 修正
},
"KTV": {
"topic": "欢唱汇",
"icon_header": "bxs-microphone",
"icon_section": "bx-microphone",
"icon_card": "bxs-microphone",
"map_legend": "KTV",
"noun_singular": "KTV",
"noun_plural": "KTV",
"theme_primary": "#D10F6F", # 修正后的粉色 (原#FF1493过亮)
"theme_primary_light": "#F03A8A",
"theme_primary_dark": "#A50C58",
"theme_secondary": "#FFB6C1",
"theme_light": "#FFF0F5",
"theme_dark": "#6B0A2E", # 修正
},
"博物馆": {
"topic": "博古汇",
"icon_header": "bxs-institution",
"icon_section": "bx-institution",
"icon_card": "bxs-institution",
"map_legend": "博物馆",
"noun_singular": "博物馆",
"noun_plural": "博物馆",
"theme_primary": "#A88517", # 二次修正 (原#B8941A: 2.88:1 -> 3.21:1 for large text)
"theme_primary_light": "#C29E1D",
"theme_primary_dark": "#896B13",
"theme_secondary": "#F0E68C",
"theme_light": "#FFFACD",
"theme_dark": "#6B5535", # 已修正 ✓
},
"景点": {
"topic": "游览汇",
"icon_header": "bxs-landmark",
"icon_section": "bx-landmark",
"icon_card": "bxs-landmark",
"map_legend": "景点",
"noun_singular": "景点",
"noun_plural": "景点",
"theme_primary": "#138496", # 保持 (对比度合格)
"theme_primary_light": "#20A5BB",
"theme_primary_dark": "#0F6875",
"theme_secondary": "#7FDBDA",
"theme_light": "#E0F7FA",
"theme_dark": "#00504A", # 修正
},
"酒吧": {
"topic": "夜宴汇",
"icon_header": "bxs-drink",
"icon_section": "bx-drink",
"icon_card": "bxs-drink",
"map_legend": "酒吧",
"noun_singular": "酒吧",
"noun_plural": "酒吧",
"theme_primary": "#2C3E50", # 保持 (对比度合格)
"theme_primary_light": "#4D5D6E",
"theme_primary_dark": "#1B2631",
"theme_secondary": "#85929E",
"theme_light": "#EBF5FB",
"theme_dark": "#0C1014", # 修正
},
"茶楼": {
"topic": "茶韵汇",
"icon_header": "bxs-coffee-bean",
"icon_section": "bx-coffee-bean",
"icon_card": "bxs-coffee-bean",
"map_legend": "茶楼",
"noun_singular": "茶楼",
"noun_plural": "茶楼",
"theme_primary": "#406058", # 修正后的绿色 (原#52796F过亮)
"theme_primary_light": "#567A6F",
"theme_primary_dark": "#2F4841",
"theme_secondary": "#CAD2C5",
"theme_light": "#F7F9F7",
"theme_dark": "#1F2D28", # 修正
},
"游泳馆": { # 新增第14个主题
"topic": "泳池汇",
"icon_header": "bx-swim",
"icon_section": "bx-swim",
"icon_card": "bx-swim",
"map_legend": "游泳馆",
"noun_singular": "游泳馆",
"noun_plural": "游泳馆",
"theme_primary": "#1E90FF", # 水蓝色
"theme_primary_light": "#4DA6FF",
"theme_primary_dark": "#1873CC",
"theme_secondary": "#87CEEB",
"theme_light": "#E0F2FE",
"theme_dark": "#0C4A6E",
},
# 默认主题 (与咖啡馆相同)
"default": {
"topic": "推荐地点",
"icon_header": "bx-map-pin",
"icon_section": "bx-location-plus",
"icon_card": "bx-map-alt",
"map_legend": "推荐地点",
"noun_singular": "地点",
"noun_plural": "地点",
"theme_primary": "#8B5A3C",
"theme_primary_light": "#B8754A",
"theme_primary_dark": "#6D4530",
"theme_secondary": "#C9ADA7",
"theme_light": "#F2E9E4",
"theme_dark": "#1A1A2E",
},
}
# ============================================================================
# 间距系统 (Spacing System)
# 基于8px基准的间距尺度
# ============================================================================
SPACING = {
"0": "0",
"1": "4px", # 0.25rem
"2": "8px", # 0.5rem
"3": "12px", # 0.75rem
"4": "16px", # 1rem
"5": "20px", # 1.25rem
"6": "24px", # 1.5rem
"8": "32px", # 2rem
"10": "40px", # 2.5rem
"12": "48px", # 3rem
"16": "64px", # 4rem
"20": "80px", # 5rem
}
# ============================================================================
# 圆角系统 (Border Radius System)
# ============================================================================
RADIUS = {
"none": "0",
"sm": "4px",
"md": "8px",
"lg": "12px",
"xl": "16px",
"2xl": "24px",
"full": "9999px",
}
# ============================================================================
# 字体系统 (Typography System) - MeetSpot 品牌字体
# Poppins (标题) - 友好且现代,比 Inter 更有个性
# Nunito (正文) - 圆润易读,传递温暖感
# ============================================================================
FONT = {
"family_heading": '"Poppins", "PingFang SC", -apple-system, BlinkMacSystemFont, sans-serif',
"family_sans": '"Nunito", "Microsoft YaHei", -apple-system, BlinkMacSystemFont, sans-serif',
"family_mono": '"JetBrains Mono", "Fira Code", "SF Mono", "Consolas", "Monaco", monospace',
# 字体大小 (基于16px基准)
"size_xs": "0.75rem", # 12px
"size_sm": "0.875rem", # 14px
"size_base": "1rem", # 16px
"size_lg": "1.125rem", # 18px
"size_xl": "1.25rem", # 20px
"size_2xl": "1.5rem", # 24px
"size_3xl": "1.875rem", # 30px
"size_4xl": "2.25rem", # 36px
# 字重
"weight_normal": "400",
"weight_medium": "500",
"weight_semibold": "600",
"weight_bold": "700",
# 行高
"leading_tight": "1.25",
"leading_normal": "1.5",
"leading_relaxed": "1.7",
"leading_loose": "2",
}
# ============================================================================
# Z-Index系统 (Layering System)
# ============================================================================
Z_INDEX = {
"dropdown": "1000",
"sticky": "1020",
"fixed": "1030",
"modal_backdrop": "1040",
"modal": "1050",
"popover": "1060",
"tooltip": "1070",
}
# ============================================================================
# 交互动画系统 (Interaction Animations)
# 遵循WCAG 2.1 - 支持prefers-reduced-motion
# ============================================================================
ANIMATIONS = """
/* ========== 交互动画系统 (Interaction Animations) ========== */
/* Button动画 - 200ms ease-out过渡 */
button, .btn, input[type="submit"], a.button {
transition: all 0.2s ease-out;
}
button:hover, .btn:hover, input[type="submit"]:hover, a.button:hover {
transform: translateY(-2px);
box-shadow: var(--shadow-lg);
}
button:active, .btn:active, input[type="submit"]:active, a.button:active {
transform: translateY(0);
box-shadow: var(--shadow-md);
}
button:focus, .btn:focus, input[type="submit"]:focus, a.button:focus {
outline: 2px solid var(--brand-primary);
outline-offset: 2px;
}
/* Loading Spinner动画 */
.loading::after {
content: "";
width: 16px;
height: 16px;
margin-left: 8px;
border: 2px solid var(--brand-primary);
border-top-color: transparent;
border-radius: 50%;
display: inline-block;
animation: spin 0.6s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
/* Card悬停效果 - 微妙的缩放和阴影提升 */
.card, .venue-card, .recommendation-card {
transition: transform 0.2s ease-out, box-shadow 0.2s ease-out;
}
.card:hover, .venue-card:hover, .recommendation-card:hover {
transform: scale(1.02);
box-shadow: var(--shadow-xl);
}
/* Fade-in渐显动画 - 400ms */
.fade-in {
animation: fadeIn 0.4s ease-out;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
/* Slide-in滑入动画 */
.slide-in {
animation: slideIn 0.4s ease-out;
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateX(-20px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
/* WCAG 2.1无障碍支持 - 尊重用户的动画偏好 */
@media (prefers-reduced-motion: reduce) {
*,
*::before,
*::after {
animation-duration: 0.01ms !important;
animation-iteration-count: 1 !important;
transition-duration: 0.01ms !important;
scroll-behavior: auto !important;
}
}
"""
# ============================================================================
# 辅助方法
# ============================================================================
@classmethod
@lru_cache(maxsize=128)
def get_venue_theme(cls, venue_type: str) -> Dict[str, str]:
"""
根据场所类型获取主题配置
Args:
venue_type: 场所类型 (如"咖啡馆""图书馆")
Returns:
包含主题色彩和图标的字典
Example:
>>> theme = DesignTokens.get_venue_theme("咖啡馆")
>>> print(theme['theme_primary']) # "#8B5A3C"
"""
return cls.VENUE_THEMES.get(venue_type, cls.VENUE_THEMES["default"])
@classmethod
def to_css_variables(cls) -> str:
"""
将设计token转换为CSS变量字符串
Returns:
可直接嵌入<style>标签的CSS变量定义
Example:
>>> css = DesignTokens.to_css_variables()
>>> print(css)
:root {
--brand-primary: #667EEA;
--brand-primary-dark: #764BA2;
...
}
"""
lines = [":root {"]
# 品牌色
for key, value in cls.BRAND.items():
css_key = f"--brand-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 文字色
for key, value in cls.TEXT.items():
css_key = f"--text-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 背景色
for key, value in cls.BACKGROUND.items():
css_key = f"--bg-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 边框色
for key, value in cls.BORDER.items():
css_key = f"--border-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 阴影
for key, value in cls.SHADOW.items():
css_key = f"--shadow-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 间距
for key, value in cls.SPACING.items():
css_key = f"--spacing-{key}"
lines.append(f" {css_key}: {value};")
# 圆角
for key, value in cls.RADIUS.items():
css_key = f"--radius-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# 字体
for key, value in cls.FONT.items():
css_key = f"--font-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
# Z-Index
for key, value in cls.Z_INDEX.items():
css_key = f"--z-{key.replace('_', '-')}"
lines.append(f" {css_key}: {value};")
lines.append("}")
return "\n".join(lines)
@classmethod
def generate_css_file(cls, output_path: str = "static/css/design-tokens.css"):
"""
生成独立的CSS设计token文件
Args:
output_path: 输出文件路径
Example:
>>> DesignTokens.generate_css_file()
# 生成 static/css/design-tokens.css
"""
import os
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write("/* ============================================\n")
f.write(" * MeetSpot Design Tokens\n")
f.write(" * 自动生成 - 请勿手动编辑\n")
f.write(" * 生成源: app/design_tokens.py\n")
f.write(" * ==========================================*/\n\n")
f.write(cls.to_css_variables())
f.write("\n\n/* Compatibility fallbacks for older browsers */\n")
f.write(".no-cssvar {\n")
f.write(" /* Fallback for browsers without CSS variable support */\n")
f.write(f" color: {cls.TEXT['primary']};\n")
f.write(f" background-color: {cls.BACKGROUND['primary']};\n")
f.write("}\n\n")
# 追加交互动画系统
f.write(cls.ANIMATIONS)
# ============================================================================
# 全局单例访问 (方便快速引用)
# ============================================================================
COLORS = {
"brand": DesignTokens.BRAND,
"text": DesignTokens.TEXT,
"background": DesignTokens.BACKGROUND,
"border": DesignTokens.BORDER,
}
VENUE_THEMES = DesignTokens.VENUE_THEMES
# ============================================================================
# 便捷函数
# ============================================================================
def get_venue_theme(venue_type: str) -> Dict[str, str]:
"""便捷函数: 获取场所主题"""
return DesignTokens.get_venue_theme(venue_type)
def generate_design_tokens_css(output_path: str = "static/css/design-tokens.css"):
"""便捷函数: 生成CSS文件"""
DesignTokens.generate_css_file(output_path)

View File

@@ -0,0 +1,13 @@
class ToolError(Exception):
"""Raised when a tool encounters an error."""
def __init__(self, message):
self.message = message
class OpenManusError(Exception):
"""Base exception for all OpenManus errors"""
class TokenLimitExceeded(OpenManusError):
"""Exception raised when the token limit is exceeded"""

800
MeetSpot/app/llm.py Normal file
View File

@@ -0,0 +1,800 @@
import math
from typing import Dict, List, Optional, Union
import tiktoken
from openai import (APIError, AsyncAzureOpenAI, AsyncOpenAI,
AuthenticationError, OpenAIError, RateLimitError)
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
wait_random_exponential)
from app.config import LLMSettings, config
from app.exceptions import TokenLimitExceeded
from app.logger import logger # Assuming a logger is set up in your app
from app.schema import (ROLE_VALUES, TOOL_CHOICE_TYPE, TOOL_CHOICE_VALUES,
Message, ToolChoice)
REASONING_MODELS = ["o1", "o3-mini"]
MULTIMODAL_MODELS = [
"gpt-4-vision-preview",
"gpt-4o",
"gpt-4o-mini",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
]
class TokenCounter:
# Token constants
BASE_MESSAGE_TOKENS = 4
FORMAT_TOKENS = 2
LOW_DETAIL_IMAGE_TOKENS = 85
HIGH_DETAIL_TILE_TOKENS = 170
# Image processing constants
MAX_SIZE = 2048
HIGH_DETAIL_TARGET_SHORT_SIDE = 768
TILE_SIZE = 512
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def count_text(self, text: str) -> int:
"""Calculate tokens for a text string"""
return 0 if not text else len(self.tokenizer.encode(text))
def count_image(self, image_item: dict) -> int:
"""
Calculate tokens for an image based on detail level and dimensions
For "low" detail: fixed 85 tokens
For "high" detail:
1. Scale to fit in 2048x2048 square
2. Scale shortest side to 768px
3. Count 512px tiles (170 tokens each)
4. Add 85 tokens
"""
detail = image_item.get("detail", "medium")
# For low detail, always return fixed token count
if detail == "low":
return self.LOW_DETAIL_IMAGE_TOKENS
# For medium detail (default in OpenAI), use high detail calculation
# OpenAI doesn't specify a separate calculation for medium
# For high detail, calculate based on dimensions if available
if detail == "high" or detail == "medium":
# If dimensions are provided in the image_item
if "dimensions" in image_item:
width, height = image_item["dimensions"]
return self._calculate_high_detail_tokens(width, height)
# Default values when dimensions aren't available or detail level is unknown
if detail == "high":
# Default to a 1024x1024 image calculation for high detail
return self._calculate_high_detail_tokens(1024, 1024) # 765 tokens
elif detail == "medium":
# Default to a medium-sized image for medium detail
return 1024 # This matches the original default
else:
# For unknown detail levels, use medium as default
return 1024
def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
"""Calculate tokens for high detail images based on dimensions"""
# Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
if width > self.MAX_SIZE or height > self.MAX_SIZE:
scale = self.MAX_SIZE / max(width, height)
width = int(width * scale)
height = int(height * scale)
# Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
scaled_width = int(width * scale)
scaled_height = int(height * scale)
# Step 3: Count number of 512px tiles
tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
total_tiles = tiles_x * tiles_y
# Step 4: Calculate final token count
return (
total_tiles * self.HIGH_DETAIL_TILE_TOKENS
) + self.LOW_DETAIL_IMAGE_TOKENS
def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
"""Calculate tokens for message content"""
if not content:
return 0
if isinstance(content, str):
return self.count_text(content)
token_count = 0
for item in content:
if isinstance(item, str):
token_count += self.count_text(item)
elif isinstance(item, dict):
if "text" in item:
token_count += self.count_text(item["text"])
elif "image_url" in item:
token_count += self.count_image(item)
return token_count
def count_tool_calls(self, tool_calls: List[dict]) -> int:
"""Calculate tokens for tool calls"""
token_count = 0
for tool_call in tool_calls:
if "function" in tool_call:
function = tool_call["function"]
token_count += self.count_text(function.get("name", ""))
token_count += self.count_text(function.get("arguments", ""))
return token_count
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the total number of tokens in a message list"""
total_tokens = self.FORMAT_TOKENS # Base format tokens
for message in messages:
tokens = self.BASE_MESSAGE_TOKENS # Base tokens per message
# Add role tokens
tokens += self.count_text(message.get("role", ""))
# Add content tokens
if "content" in message:
tokens += self.count_content(message["content"])
# Add tool calls tokens
if "tool_calls" in message:
tokens += self.count_tool_calls(message["tool_calls"])
# Add name and tool_call_id tokens
tokens += self.count_text(message.get("name", ""))
tokens += self.count_text(message.get("tool_call_id", ""))
total_tokens += tokens
return total_tokens
class LLM:
_instances: Dict[str, "LLM"] = {}
def __new__(
cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
):
if config_name not in cls._instances:
instance = super().__new__(cls)
instance.__init__(config_name, llm_config)
cls._instances[config_name] = instance
return cls._instances[config_name]
def __init__(
self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
):
if not hasattr(self, "client"): # Only initialize if not already initialized
llm_config = llm_config or config.llm
llm_config = llm_config.get(config_name, llm_config["default"])
self.model = llm_config.model
self.max_tokens = llm_config.max_tokens
self.temperature = llm_config.temperature
self.api_type = llm_config.api_type
self.api_key = llm_config.api_key
self.api_version = llm_config.api_version
self.base_url = llm_config.base_url
# Add token counting related attributes
self.total_input_tokens = 0
self.total_completion_tokens = 0
self.max_input_tokens = (
llm_config.max_input_tokens
if hasattr(llm_config, "max_input_tokens")
else None
)
# Initialize tokenizer
try:
self.tokenizer = tiktoken.encoding_for_model(self.model)
except KeyError:
# If the model is not in tiktoken's presets, use cl100k_base as default
self.tokenizer = tiktoken.get_encoding("cl100k_base")
if self.api_type == "azure":
self.client = AsyncAzureOpenAI(
base_url=self.base_url,
api_key=self.api_key,
api_version=self.api_version,
)
else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self.token_counter = TokenCounter(self.tokenizer)
def count_tokens(self, text: str) -> int:
"""Calculate the number of tokens in a text"""
if not text:
return 0
return len(self.tokenizer.encode(text))
def count_message_tokens(self, messages: List[dict]) -> int:
return self.token_counter.count_message_tokens(messages)
def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
"""Update token counts"""
# Only track tokens if max_input_tokens is set
self.total_input_tokens += input_tokens
self.total_completion_tokens += completion_tokens
logger.info(
f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
)
def check_token_limit(self, input_tokens: int) -> bool:
"""Check if token limits are exceeded"""
if self.max_input_tokens is not None:
return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
# If max_input_tokens is not set, always return True
return True
def get_limit_error_message(self, input_tokens: int) -> str:
"""Generate error message for token limit exceeded"""
if (
self.max_input_tokens is not None
and (self.total_input_tokens + input_tokens) > self.max_input_tokens
):
return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
return "Token limit exceeded"
@staticmethod
def format_messages(
messages: List[Union[dict, Message]], supports_images: bool = False
) -> List[dict]:
"""
Format messages for LLM by converting them to OpenAI message format.
Args:
messages: List of messages that can be either dict or Message objects
supports_images: Flag indicating if the target model supports image inputs
Returns:
List[dict]: List of formatted messages in OpenAI format
Raises:
ValueError: If messages are invalid or missing required fields
TypeError: If unsupported message types are provided
Examples:
>>> msgs = [
... Message.system_message("You are a helpful assistant"),
... {"role": "user", "content": "Hello"},
... Message.user_message("How are you?")
... ]
>>> formatted = LLM.format_messages(msgs)
"""
formatted_messages = []
for message in messages:
# Convert Message objects to dictionaries
if isinstance(message, Message):
message = message.to_dict()
if isinstance(message, dict):
# If message is a dict, ensure it has required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
# Process base64 images if present and model supports images
if supports_images and message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [
{"type": "text", "text": message["content"]}
]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# If model doesn't support images but message has base64_image, handle gracefully
elif not supports_images and message.get("base64_image"):
# Just remove the base64_image field and keep the text content
del message["base64_image"]
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# else: do not include the message
else:
raise TypeError(f"Unsupported message type: {type(message)}")
# Validate all messages have required fields
for msg in formatted_messages:
if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}")
return formatted_messages
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
async def ask(
self,
messages: List[Union[dict, Message]],
system_msgs: Optional[List[Union[dict, Message]]] = None,
stream: bool = True,
temperature: Optional[float] = None,
) -> str:
"""
Send a prompt to the LLM and get the response.
Args:
messages: List of conversation messages
system_msgs: Optional system messages to prepend
stream (bool): Whether to stream the response
temperature (float): Sampling temperature for the response
Returns:
str: The generated response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If messages are invalid or response is empty
OpenAIError: If API call fails after retries
Exception: For unexpected errors
"""
try:
# Check if the model supports images
supports_images = self.model in MULTIMODAL_MODELS
# Format system and user messages with image support check
if system_msgs:
system_msgs = self.format_messages(system_msgs, supports_images)
messages = system_msgs + self.format_messages(messages, supports_images)
else:
messages = self.format_messages(messages, supports_images)
# Calculate input token count
input_tokens = self.count_message_tokens(messages)
# Check if token limits are exceeded
if not self.check_token_limit(input_tokens):
error_message = self.get_limit_error_message(input_tokens)
# Raise a special exception that won't be retried
raise TokenLimitExceeded(error_message)
params = {
"model": self.model,
"messages": messages,
}
if self.model in REASONING_MODELS:
params["max_completion_tokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
params["temperature"] = (
temperature if temperature is not None else self.temperature
)
if not stream:
# Non-streaming request
response = await self.client.chat.completions.create(
**params, stream=False
)
if not response.choices or not response.choices[0].message.content:
raise ValueError("Empty or invalid response from LLM")
# Update token counts
self.update_token_count(
response.usage.prompt_tokens, response.usage.completion_tokens
)
return response.choices[0].message.content
# Streaming request, For streaming, update estimated token count before making the request
self.update_token_count(input_tokens)
response = await self.client.chat.completions.create(**params, stream=True)
collected_messages = []
completion_text = ""
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or ""
collected_messages.append(chunk_message)
completion_text += chunk_message
print(chunk_message, end="", flush=True)
print() # Newline after streaming
full_response = "".join(collected_messages).strip()
if not full_response:
raise ValueError("Empty response from streaming LLM")
# estimate completion tokens for streaming response
completion_tokens = self.count_tokens(completion_text)
logger.info(
f"Estimated completion tokens for streaming response: {completion_tokens}"
)
self.total_completion_tokens += completion_tokens
return full_response
except TokenLimitExceeded:
# Re-raise token limit errors without logging
raise
except ValueError:
logger.exception(f"Validation error")
raise
except OpenAIError as oe:
logger.exception(f"OpenAI API error")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception:
logger.exception(f"Unexpected error in ask")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
async def ask_with_images(
self,
messages: List[Union[dict, Message]],
images: List[Union[str, dict]],
system_msgs: Optional[List[Union[dict, Message]]] = None,
stream: bool = False,
temperature: Optional[float] = None,
) -> str:
"""
Send a prompt with images to the LLM and get the response.
Args:
messages: List of conversation messages
images: List of image URLs or image data dictionaries
system_msgs: Optional system messages to prepend
stream (bool): Whether to stream the response
temperature (float): Sampling temperature for the response
Returns:
str: The generated response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If messages are invalid or response is empty
OpenAIError: If API call fails after retries
Exception: For unexpected errors
"""
try:
# For ask_with_images, we always set supports_images to True because
# this method should only be called with models that support images
if self.model not in MULTIMODAL_MODELS:
raise ValueError(
f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
)
# Format messages with image support
formatted_messages = self.format_messages(messages, supports_images=True)
# Ensure the last message is from the user to attach images
if not formatted_messages or formatted_messages[-1]["role"] != "user":
raise ValueError(
"The last message must be from the user to attach images"
)
# Process the last user message to include images
last_message = formatted_messages[-1]
# Convert content to multimodal format if needed
content = last_message["content"]
multimodal_content = (
[{"type": "text", "text": content}]
if isinstance(content, str)
else content if isinstance(content, list) else []
)
# Add images to content
for image in images:
if isinstance(image, str):
multimodal_content.append(
{"type": "image_url", "image_url": {"url": image}}
)
elif isinstance(image, dict) and "url" in image:
multimodal_content.append({"type": "image_url", "image_url": image})
elif isinstance(image, dict) and "image_url" in image:
multimodal_content.append(image)
else:
raise ValueError(f"Unsupported image format: {image}")
# Update the message with multimodal content
last_message["content"] = multimodal_content
# Add system messages if provided
if system_msgs:
all_messages = (
self.format_messages(system_msgs, supports_images=True)
+ formatted_messages
)
else:
all_messages = formatted_messages
# Calculate tokens and check limits
input_tokens = self.count_message_tokens(all_messages)
if not self.check_token_limit(input_tokens):
raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))
# Set up API parameters
params = {
"model": self.model,
"messages": all_messages,
"stream": stream,
}
# Add model-specific parameters
if self.model in REASONING_MODELS:
params["max_completion_tokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
params["temperature"] = (
temperature if temperature is not None else self.temperature
)
# Handle non-streaming request
if not stream:
response = await self.client.chat.completions.create(**params)
if not response.choices or not response.choices[0].message.content:
raise ValueError("Empty or invalid response from LLM")
self.update_token_count(
response.usage.prompt_tokens, response.usage.completion_tokens
)
return response.choices[0].message.content
# Handle streaming request
response = await self.client.chat.completions.create(**params)
collected_messages = []
completion_text = ""
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or ""
collected_messages.append(chunk_message)
completion_text += chunk_message
print(chunk_message, end="", flush=True)
print() # Newline after streaming
full_response = "".join(collected_messages).strip()
if not full_response:
raise ValueError("Empty response from streaming LLM")
completion_tokens = self.count_tokens(completion_text)
logger.info(
f"Estimated completion tokens for streaming response with images: {completion_tokens}"
)
self.update_token_count(input_tokens, completion_tokens)
return full_response
except TokenLimitExceeded:
raise
except ValueError as ve:
logger.error(f"Validation error in ask_with_images: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception as e:
logger.error(f"Unexpected error in ask_with_images: {e}")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
async def ask_tool(
self,
messages: List[Union[dict, Message]],
system_msgs: Optional[List[Union[dict, Message]]] = None,
timeout: int = 300,
tools: Optional[List[dict]] = None,
tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
temperature: Optional[float] = None,
**kwargs,
) -> ChatCompletionMessage | None:
"""
Ask LLM using functions/tools and return the response.
Args:
messages: List of conversation messages
system_msgs: Optional system messages to prepend
timeout: Request timeout in seconds
tools: List of tools to use
tool_choice: Tool choice strategy
temperature: Sampling temperature for the response
**kwargs: Additional completion arguments
Returns:
ChatCompletionMessage: The model's response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If tools, tool_choice, or messages are invalid
OpenAIError: If API call fails after retries
Exception: For unexpected errors
"""
try:
# Validate tool_choice
if tool_choice not in TOOL_CHOICE_VALUES:
raise ValueError(f"Invalid tool_choice: {tool_choice}")
# Check if the model supports images
supports_images = self.model in MULTIMODAL_MODELS
# Format messages
if system_msgs:
system_msgs = self.format_messages(system_msgs, supports_images)
formatted_messages = system_msgs + self.format_messages(messages, supports_images)
else:
formatted_messages = self.format_messages(messages, supports_images)
# 验证消息序列确保tool消息前面有带tool_calls的assistant消息
valid_messages = []
tool_calls_ids = set() # 跟踪所有有效的tool_call IDs
for i, msg in enumerate(formatted_messages):
if isinstance(msg, dict):
role = msg.get("role")
tool_call_id = msg.get("tool_call_id")
else:
role = msg.role if hasattr(msg, "role") else None
tool_call_id = msg.tool_call_id if hasattr(msg, "tool_call_id") else None
# 如果是tool消息验证它引用的tool_call_id是否存在
if role == "tool" and tool_call_id:
if tool_call_id not in tool_calls_ids:
logger.warning(f"发现无效的工具消息 - tool_call_id '{tool_call_id}' 未找到关联的tool_calls将跳过此消息")
continue
# 如果是assistant消息且有tool_calls记录所有tool_call IDs
elif role == "assistant":
tool_calls = []
if isinstance(msg, dict) and "tool_calls" in msg:
tool_calls = msg.get("tool_calls", [])
elif hasattr(msg, "tool_calls") and msg.tool_calls:
tool_calls = msg.tool_calls
# 收集所有工具调用ID
for call in tool_calls:
if isinstance(call, dict) and "id" in call:
tool_calls_ids.add(call["id"])
elif hasattr(call, "id"):
tool_calls_ids.add(call.id)
# 添加有效消息
valid_messages.append(msg)
# 使用验证过的消息序列替换原来的消息
formatted_messages = valid_messages
# Calculate input token count
input_tokens = self.count_message_tokens(formatted_messages)
# If there are tools, calculate token count for tool descriptions
tools_tokens = 0
if tools:
for tool in tools:
tools_tokens += self.count_tokens(str(tool))
input_tokens += tools_tokens
# Check if token limits are exceeded
if not self.check_token_limit(input_tokens):
error_message = self.get_limit_error_message(input_tokens)
# Raise a special exception that won't be retried
raise TokenLimitExceeded(error_message)
# Validate tools if provided
if tools:
for tool in tools:
if not isinstance(tool, dict) or "type" not in tool:
raise ValueError("Each tool must be a dict with 'type' field")
# Set up the completion request
params = {
"model": self.model,
"messages": formatted_messages,
"tools": tools,
"tool_choice": tool_choice,
"timeout": timeout,
**kwargs,
}
if self.model in REASONING_MODELS:
params["max_completion_tokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
params["temperature"] = (
temperature if temperature is not None else self.temperature
)
response: ChatCompletion = await self.client.chat.completions.create(
**params, stream=False
)
# Check if response is valid
if not response.choices or not response.choices[0].message:
print(response)
# raise ValueError("Invalid or empty response from LLM")
return None
# Update token counts
self.update_token_count(
response.usage.prompt_tokens, response.usage.completion_tokens
)
return response.choices[0].message
except TokenLimitExceeded:
# Re-raise token limit errors without logging
raise
except ValueError as ve:
logger.error(f"Validation error in ask_tool: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception as e:
logger.error(f"Unexpected error in ask_tool: {e}")
raise

42
MeetSpot/app/logger.py Normal file
View File

@@ -0,0 +1,42 @@
import sys
from datetime import datetime
from loguru import logger as _logger
from app.config import PROJECT_ROOT
_print_level = "INFO"
def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None):
"""Adjust the log level to above level"""
global _print_level
_print_level = print_level
current_date = datetime.now()
formatted_date = current_date.strftime("%Y%m%d%H%M%S")
log_name = (
f"{name}_{formatted_date}" if name else formatted_date
) # name a log with prefix name
_logger.remove()
_logger.add(sys.stderr, level=print_level)
_logger.add(PROJECT_ROOT / f"logs/{log_name}.log", level=logfile_level)
return _logger
logger = define_log_level()
if __name__ == "__main__":
logger.info("Starting application")
logger.debug("Debug message")
logger.warning("Warning message")
logger.error("Error message")
logger.critical("Critical message")
try:
raise ValueError("Test error")
except Exception as e:
logger.exception(f"An error occurred: {e}")

View File

@@ -0,0 +1,9 @@
"""数据模型包。"""
from app.db.database import Base # noqa: F401
# 导入所有模型以确保它们注册到Base.metadata
from app.models.user import User # noqa: F401
from app.models.room import GatheringRoom, RoomParticipant # noqa: F401
from app.models.message import ChatMessage, VenueVote # noqa: F401

View File

@@ -0,0 +1,53 @@
"""聊天消息与投票记录模型。"""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, ForeignKey, String, Text, UniqueConstraint, func
from app.db.database import Base
def _generate_uuid() -> str:
return str(uuid.uuid4())
class VenueVote(Base):
__tablename__ = "venue_votes"
__table_args__ = (UniqueConstraint("room_id", "venue_id", "user_id", name="uq_vote"),)
id = Column(String(36), primary_key=True, default=_generate_uuid)
room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
venue_id = Column(String(100), nullable=False)
user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
vote_type = Column(String(20), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class ChatMessage(Base):
__tablename__ = "chat_messages"
id = Column(String(36), primary_key=True, default=_generate_uuid)
room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
content = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class ChatMessageCreate(BaseModel):
content: str = Field(..., min_length=1, description="聊天内容")
class VoteCreate(BaseModel):
venue_id: str
vote_type: str = Field(..., pattern="^(like|dislike)$")
class VoteRead(BaseModel):
venue_id: str
vote_type: str
user_id: str
created_at: Optional[datetime] = None

View File

@@ -0,0 +1,60 @@
"""聚会房间相关的ORM与Pydantic模型。"""
import uuid
from datetime import datetime
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Float, ForeignKey, String, Text, UniqueConstraint, func
from app.db.database import Base
def _generate_uuid() -> str:
return str(uuid.uuid4())
class GatheringRoom(Base):
__tablename__ = "gathering_rooms"
id = Column(String(36), primary_key=True, default=_generate_uuid)
name = Column(String(100), nullable=False)
description = Column(Text, default="")
host_user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
gathering_time = Column(DateTime(timezone=True))
status = Column(String(20), default="pending")
venue_keywords = Column(String(100), default="咖啡馆")
final_venue_json = Column(Text, nullable=True)
class RoomParticipant(Base):
__tablename__ = "room_participants"
__table_args__ = (UniqueConstraint("room_id", "user_id", name="uq_room_user"),)
id = Column(String(36), primary_key=True, default=_generate_uuid)
room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
location_name = Column(String(200))
location_lat = Column(Float)
location_lng = Column(Float)
joined_at = Column(DateTime(timezone=True), server_default=func.now())
role = Column(String(20), default="member")
class GatheringRoomCreate(BaseModel):
name: str = Field(..., description="聚会名称")
description: str = Field("", description="聚会描述")
gathering_time: Optional[datetime] = Field(
None, description="聚会时间ISO 字符串"
)
venue_keywords: str = Field("咖啡馆", description="场所类型关键词")
class RoomParticipantRead(BaseModel):
user_id: str
nickname: str
location_name: Optional[str] = None
location_coords: Optional[Tuple[float, float]] = None
role: str

View File

@@ -0,0 +1,44 @@
"""用户相关SQLAlchemy模型与Pydantic模式。"""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, String, func
from app.db.database import Base
def _generate_uuid() -> str:
return str(uuid.uuid4())
class User(Base):
__tablename__ = "users"
id = Column(String(36), primary_key=True, default=_generate_uuid)
phone = Column(String(20), unique=True, nullable=False)
nickname = Column(String(50), nullable=False)
avatar_url = Column(String(255), default="")
created_at = Column(DateTime(timezone=True), server_default=func.now())
last_login = Column(DateTime(timezone=True))
class UserCreate(BaseModel):
phone: str = Field(..., description="手机号")
nickname: Optional[str] = Field(None, description="昵称,可选")
avatar_url: Optional[str] = Field("", description="头像URL可选")
class UserRead(BaseModel):
id: str
phone: str
nickname: str
avatar_url: str = ""
created_at: datetime
last_login: Optional[datetime] = None
class Config:
from_attributes = True

214
MeetSpot/app/schema.py Normal file
View File

@@ -0,0 +1,214 @@
from enum import Enum
from typing import Any, List, Literal, Optional, Union
from pydantic import BaseModel, Field
class Role(str, Enum):
"""Message role options"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
class ToolChoice(str, Enum):
"""Tool choice options"""
NONE = "none"
AUTO = "auto"
REQUIRED = "required"
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
class AgentState(str, Enum):
"""Agent execution states"""
IDLE = "IDLE"
RUNNING = "RUNNING"
FINISHED = "FINISHED"
ERROR = "ERROR"
class Function(BaseModel):
name: str
arguments: str
class ToolCall(BaseModel):
"""Represents a tool/function call in a message"""
id: str
type: str = "function"
function: Function
class Message(BaseModel):
"""Represents a chat message in the conversation"""
role: ROLE_TYPE = Field(...) # type: ignore
content: Optional[str] = Field(default=None)
tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None)
tool_call_id: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
def __add__(self, other) -> List["Message"]:
"""支持 Message + list 或 Message + Message 的操作"""
if isinstance(other, list):
return [self] + other
elif isinstance(other, Message):
return [self, other]
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
)
def __radd__(self, other) -> List["Message"]:
"""支持 list + Message 的操作"""
if isinstance(other, list):
return other + [self]
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
)
def to_dict(self) -> dict:
"""Convert message to dictionary format"""
message = {"role": self.role}
if self.content is not None:
message["content"] = self.content
if self.tool_calls is not None and self.role == Role.ASSISTANT:
message["tool_calls"] = [
tool_call.dict() if hasattr(tool_call, "dict") else tool_call
for tool_call in self.tool_calls
]
if self.name is not None and self.role == Role.TOOL:
message["name"] = self.name
if self.tool_call_id is not None and self.role == Role.TOOL:
message["tool_call_id"] = self.tool_call_id
# 不要在API调用中包含base64_image这不是OpenAI API消息格式的一部分
return message
@classmethod
def user_message(
cls, content: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a user message"""
return cls(role=Role.USER, content=content, base64_image=base64_image)
@classmethod
def system_message(cls, content: str) -> "Message":
"""Create a system message"""
return cls(role=Role.SYSTEM, content=content)
@classmethod
def assistant_message(
cls, content: Optional[str] = None, base64_image: Optional[str] = None
) -> "Message":
"""Create an assistant message"""
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
@classmethod
def tool_message(
cls, content: str, name: str, tool_call_id: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a tool message
Args:
content: The content/result of the tool execution
name: The name of the tool that was executed
tool_call_id: The ID of the tool call this message is responding to
base64_image: Optional base64 encoded image
"""
if not tool_call_id:
raise ValueError("tool_call_id is required for tool messages")
if not name:
raise ValueError("name is required for tool messages")
return cls(
role=Role.TOOL,
content=content,
name=name,
tool_call_id=tool_call_id,
base64_image=base64_image,
)
@classmethod
def from_tool_calls(
cls,
tool_calls: List[Any],
content: Union[str, List[str]] = "",
base64_image: Optional[str] = None,
**kwargs,
):
"""Create ToolCallsMessage from raw tool calls.
Args:
tool_calls: Raw tool calls from LLM
content: Optional message content
base64_image: Optional base64 encoded image
"""
# 确保tool_calls是正确格式的对象列表
formatted_calls = []
for call in tool_calls:
if hasattr(call, "id") and hasattr(call, "function"):
func_data = call.function
if hasattr(func_data, "model_dump"):
func_dict = func_data.model_dump()
else:
func_dict = {"name": func_data.name, "arguments": func_data.arguments}
formatted_call = {
"id": call.id,
"type": "function",
"function": func_dict
}
formatted_calls.append(formatted_call)
else:
# 如果已经是字典格式,直接使用
formatted_calls.append(call)
return cls(
role=Role.ASSISTANT,
content=content,
tool_calls=formatted_calls,
base64_image=base64_image,
**kwargs,
)
class Memory(BaseModel):
messages: List[Message] = Field(default_factory=list)
max_messages: int = Field(default=100)
def add_message(self, message: Message) -> None:
"""Add a message to memory"""
self.messages.append(message)
# Optional: Implement message limit
if len(self.messages) > self.max_messages:
self.messages = self.messages[-self.max_messages :]
def add_messages(self, messages: List[Message]) -> None:
"""Add multiple messages to memory"""
self.messages.extend(messages)
def clear(self) -> None:
"""Clear all messages"""
self.messages.clear()
def get_recent_messages(self, n: int) -> List[Message]:
"""Get n most recent messages"""
return self.messages[-n:]
def to_dict_list(self) -> List[dict]:
"""Convert messages to list of dicts"""
return [msg.to_dict() for msg in self.messages]

View File

@@ -0,0 +1,9 @@
from app.tool.base import BaseTool
from app.tool.meetspot_recommender import CafeRecommender
from app.tool.tool_collection import ToolCollection
__all__ = [
"BaseTool",
"CafeRecommender",
"ToolCollection",
]

101
MeetSpot/app/tool/base.py Normal file
View File

@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class BaseTool(ABC, BaseModel):
name: str
description: str
parameters: Optional[dict] = None
class Config:
arbitrary_types_allowed = True
async def __call__(self, **kwargs) -> Any:
"""Execute the tool with given parameters."""
return await self.execute(**kwargs)
@abstractmethod
async def execute(self, **kwargs) -> Any:
"""Execute the tool with given parameters."""
def to_param(self) -> Dict:
"""Convert tool to function call format."""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
}
class ToolResult(BaseModel):
"""Represents the result of a tool execution."""
output: Any = Field(default=None)
error: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
system: Optional[str] = Field(default=None)
class Config:
arbitrary_types_allowed = True
def __bool__(self):
return any(getattr(self, field) for field in self.__fields__)
def __add__(self, other: "ToolResult"):
def combine_fields(
field: Optional[str], other_field: Optional[str], concatenate: bool = True
):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)
def __str__(self):
return f"Error: {self.error}" if self.error else self.output
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
# return self.copy(update=kwargs)
return type(self)(**{**self.dict(), **kwargs})
class CLIResult(ToolResult):
"""A ToolResult that can be rendered as a CLI output."""
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
# 为 BaseTool 添加辅助方法
def _success_response(data) -> ToolResult:
"""创建成功的工具结果"""
import json
if isinstance(data, str):
text = data
else:
text = json.dumps(data, ensure_ascii=False, indent=2)
return ToolResult(output=text)
def _fail_response(msg: str) -> ToolResult:
"""创建失败的工具结果"""
return ToolResult(error=msg)
# 将辅助方法添加到 BaseTool
BaseTool.success_response = staticmethod(_success_response)
BaseTool.fail_response = staticmethod(_fail_response)

View File

@@ -0,0 +1,158 @@
"""File operation interfaces and implementations for local and sandbox environments."""
import asyncio
from pathlib import Path
from typing import Optional, Protocol, Tuple, Union, runtime_checkable
from app.config import SandboxSettings
from app.exceptions import ToolError
from app.sandbox.client import SANDBOX_CLIENT
PathLike = Union[str, Path]
@runtime_checkable
class FileOperator(Protocol):
"""Interface for file operations in different environments."""
async def read_file(self, path: PathLike) -> str:
"""Read content from a file."""
...
async def write_file(self, path: PathLike, content: str) -> None:
"""Write content to a file."""
...
async def is_directory(self, path: PathLike) -> bool:
"""Check if path points to a directory."""
...
async def exists(self, path: PathLike) -> bool:
"""Check if path exists."""
...
async def run_command(
self, cmd: str, timeout: Optional[float] = 120.0
) -> Tuple[int, str, str]:
"""Run a shell command and return (return_code, stdout, stderr)."""
...
class LocalFileOperator(FileOperator):
"""File operations implementation for local filesystem."""
encoding: str = "utf-8"
async def read_file(self, path: PathLike) -> str:
"""Read content from a local file."""
try:
return Path(path).read_text(encoding=self.encoding)
except Exception as e:
raise ToolError(f"Failed to read {path}: {str(e)}") from None
async def write_file(self, path: PathLike, content: str) -> None:
"""Write content to a local file."""
try:
Path(path).write_text(content, encoding=self.encoding)
except Exception as e:
raise ToolError(f"Failed to write to {path}: {str(e)}") from None
async def is_directory(self, path: PathLike) -> bool:
"""Check if path points to a directory."""
return Path(path).is_dir()
async def exists(self, path: PathLike) -> bool:
"""Check if path exists."""
return Path(path).exists()
async def run_command(
self, cmd: str, timeout: Optional[float] = 120.0
) -> Tuple[int, str, str]:
"""Run a shell command locally."""
process = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
)
return (
process.returncode or 0,
stdout.decode(),
stderr.decode(),
)
except asyncio.TimeoutError as exc:
try:
process.kill()
except ProcessLookupError:
pass
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds"
) from exc
class SandboxFileOperator(FileOperator):
"""File operations implementation for sandbox environment."""
def __init__(self):
self.sandbox_client = SANDBOX_CLIENT
async def _ensure_sandbox_initialized(self):
"""Ensure sandbox is initialized."""
if not self.sandbox_client.sandbox:
await self.sandbox_client.create(config=SandboxSettings())
async def read_file(self, path: PathLike) -> str:
"""Read content from a file in sandbox."""
await self._ensure_sandbox_initialized()
try:
return await self.sandbox_client.read_file(str(path))
except Exception as e:
raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None
async def write_file(self, path: PathLike, content: str) -> None:
"""Write content to a file in sandbox."""
await self._ensure_sandbox_initialized()
try:
await self.sandbox_client.write_file(str(path), content)
except Exception as e:
raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None
async def is_directory(self, path: PathLike) -> bool:
"""Check if path points to a directory in sandbox."""
await self._ensure_sandbox_initialized()
result = await self.sandbox_client.run_command(
f"test -d {path} && echo 'true' || echo 'false'"
)
return result.strip() == "true"
async def exists(self, path: PathLike) -> bool:
"""Check if path exists in sandbox."""
await self._ensure_sandbox_initialized()
result = await self.sandbox_client.run_command(
f"test -e {path} && echo 'true' || echo 'false'"
)
return result.strip() == "true"
async def run_command(
self, cmd: str, timeout: Optional[float] = 120.0
) -> Tuple[int, str, str]:
"""Run a command in sandbox environment."""
await self._ensure_sandbox_initialized()
try:
stdout = await self.sandbox_client.run_command(
cmd, timeout=int(timeout) if timeout else None
)
return (
0, # Always return 0 since we don't have explicit return code from sandbox
stdout,
"", # No stderr capture in the current sandbox implementation
)
except TimeoutError as exc:
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds in sandbox"
) from exc
except Exception as exc:
return 1, "", f"Error executing command in sandbox: {str(exc)}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
"""Collection classes for managing multiple tools."""
import json
from typing import Any, Dict, List, Union
from app.exceptions import ToolError
from app.tool.base import BaseTool, ToolFailure, ToolResult
class ToolCollection:
"""A collection of defined tools."""
class Config:
arbitrary_types_allowed = True
def __init__(self, *tools: BaseTool):
self.tools = tools
self.tool_map = {tool.name: tool for tool in tools}
def __iter__(self):
return iter(self.tools)
def to_params(self) -> List[Dict[str, Any]]:
return [tool.to_param() for tool in self.tools]
async def execute(self, name: str, tool_input: Union[str, dict]) -> ToolResult:
"""Execute a tool by name with given input."""
tool = self.get_tool(name)
if not tool:
return ToolResult(error=f"Tool '{name}' not found")
# 确保 tool_input 是字典类型
if isinstance(tool_input, str):
try:
tool_input = json.loads(tool_input)
except json.JSONDecodeError:
return ToolResult(error=f"Invalid tool input format: {tool_input}")
result = await tool(**tool_input)
return result
async def execute_all(self) -> List[ToolResult]:
"""Execute all tools in the collection sequentially."""
results = []
for tool in self.tools:
try:
result = await tool()
results.append(result)
except ToolError as e:
results.append(ToolFailure(error=e.message))
return results
def get_tool(self, name: str) -> BaseTool:
return self.tool_map.get(name)
def add_tool(self, tool: BaseTool):
self.tools += (tool,)
self.tool_map[tool.name] = tool
return self
def add_tools(self, *tools: BaseTool):
for tool in tools:
self.add_tool(tool)
return self

View File

@@ -0,0 +1,101 @@
import asyncio
from typing import List
from tenacity import retry, stop_after_attempt, wait_exponential
from app.config import config
from app.tool.base import BaseTool
from app.tool.search import (
BaiduSearchEngine,
BingSearchEngine,
DuckDuckGoSearchEngine,
GoogleSearchEngine,
WebSearchEngine,
)
class WebSearch(BaseTool):
name: str = "web_search"
description: str = """Perform a web search and return a list of relevant links.
This function attempts to use the primary search engine API to get up-to-date results.
If an error occurs, it falls back to an alternative search engine."""
parameters: dict = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "(required) The search query to submit to the search engine.",
},
"num_results": {
"type": "integer",
"description": "(optional) The number of search results to return. Default is 10.",
"default": 10,
},
},
"required": ["query"],
}
_search_engine: dict[str, WebSearchEngine] = {
"google": GoogleSearchEngine(),
"baidu": BaiduSearchEngine(),
"duckduckgo": DuckDuckGoSearchEngine(),
"bing": BingSearchEngine(),
}
async def execute(self, query: str, num_results: int = 10) -> List[str]:
"""
Execute a Web search and return a list of URLs.
Args:
query (str): The search query to submit to the search engine.
num_results (int, optional): The number of search results to return. Default is 10.
Returns:
List[str]: A list of URLs matching the search query.
"""
engine_order = self._get_engine_order()
for engine_name in engine_order:
engine = self._search_engine[engine_name]
try:
links = await self._perform_search_with_engine(
engine, query, num_results
)
if links:
return links
except Exception as e:
print(f"Search engine '{engine_name}' failed with error: {e}")
return []
def _get_engine_order(self) -> List[str]:
"""
Determines the order in which to try search engines.
Preferred engine is first (based on configuration), followed by the remaining engines.
Returns:
List[str]: Ordered list of search engine names.
"""
preferred = "google"
if config.search_config and config.search_config.engine:
preferred = config.search_config.engine.lower()
engine_order = []
if preferred in self._search_engine:
engine_order.append(preferred)
for key in self._search_engine:
if key not in engine_order:
engine_order.append(key)
return engine_order
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
)
async def _perform_search_with_engine(
self,
engine: WebSearchEngine,
query: str,
num_results: int,
) -> List[str]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, lambda: list(engine.perform_search(query, num_results=num_results))
)