first commit
This commit is contained in:
9
MeetSpot/app/__init__.py
Normal file
9
MeetSpot/app/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Python version check: this package supports 3.11-3.13.
import sys

# Compare only (major, minor): the previous check `sys.version_info > (3, 13)`
# was true for every 3.13.x release (e.g. (3, 13, 1) > (3, 13)), so supported
# interpreters were warned about incorrectly.
if not ((3, 11) <= sys.version_info[:2] <= (3, 13)):
    print(
        "Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
            # Only major.minor.micro; joining the full version_info would also
            # include releaselevel/serial (e.g. "3.13.1.final.0").
            ver=".".join(map(str, sys.version_info[:3]))
        )
    )
|
||||
20
MeetSpot/app/agent/__init__.py
Normal file
20
MeetSpot/app/agent/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""MeetSpot Agent Module - 基于 OpenManus 架构的智能推荐 Agent"""
|
||||
|
||||
from app.agent.base import BaseAgent
|
||||
from app.agent.meetspot_agent import MeetSpotAgent, create_meetspot_agent
|
||||
from app.agent.tools import (
|
||||
CalculateCenterTool,
|
||||
GeocodeTool,
|
||||
GenerateRecommendationTool,
|
||||
SearchPOITool,
|
||||
)
|
||||
|
||||
# Public API of the agent package.
__all__ = [
    "BaseAgent",
    "MeetSpotAgent",
    "create_meetspot_agent",
    "GeocodeTool",
    "CalculateCenterTool",
    "SearchPOITool",
    "GenerateRecommendationTool",
]
|
||||
171
MeetSpot/app/agent/base.py
Normal file
171
MeetSpot/app/agent/base.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Agent 基类 - 参考 OpenManus BaseAgent 设计"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Memory, Message, ROLE_TYPE
|
||||
|
||||
|
||||
class BaseAgent(BaseModel, ABC):
    """Abstract agent base class (modeled after OpenManus ``BaseAgent``).

    Provides state management, memory handling and the main execution loop.
    Subclasses must implement :meth:`step` to define concrete behaviour.
    """

    # Core identity
    name: str = Field(default="BaseAgent", description="Agent 名称")
    description: Optional[str] = Field(default=None, description="Agent 描述")

    # Prompts
    system_prompt: Optional[str] = Field(default=None, description="系统提示词")
    next_step_prompt: Optional[str] = Field(default=None, description="下一步提示词")

    # Collaborators
    llm: Optional[LLM] = Field(default=None, description="LLM 实例")
    memory: Memory = Field(default_factory=Memory, description="Agent 记忆")
    state: AgentState = Field(default=AgentState.IDLE, description="当前状态")

    # Execution control
    max_steps: int = Field(default=10, description="最大执行步数")
    current_step: int = Field(default=0, description="当前步数")

    # Number of identical assistant replies before the agent counts as stuck.
    duplicate_threshold: int = 2

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"

    @model_validator(mode="after")
    def initialize_agent(self) -> "BaseAgent":
        """Create default collaborators after pydantic validation."""
        if self.llm is None:
            try:
                self.llm = LLM()
            except Exception as e:
                logger.warning(f"无法初始化 LLM: {e}")
        if not isinstance(self.memory, Memory):
            self.memory = Memory()
        return self

    @asynccontextmanager
    async def state_context(self, new_state: AgentState):
        """Context manager for safe state transitions.

        Restores the previous state on normal exit.  On error the agent is
        left in ``AgentState.ERROR``: previously a ``finally`` clause
        unconditionally restored the old state, clobbering ERROR and hiding
        failures from callers.
        """
        if not isinstance(new_state, AgentState):
            raise ValueError(f"无效状态: {new_state}")

        previous_state = self.state
        self.state = new_state
        try:
            yield
        except Exception:
            # Keep the ERROR state visible; bare `raise` preserves the
            # original traceback.
            self.state = AgentState.ERROR
            raise
        else:
            self.state = previous_state

    def update_memory(
        self,
        role: ROLE_TYPE,
        content: str,
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Append a message with the given role to the agent's memory.

        Args:
            role: One of "user", "system", "assistant" or "tool".
            content: Message text.
            base64_image: Optional inline image (ignored for tool messages).
            **kwargs: Extra keyword args forwarded to tool messages only.

        Raises:
            ValueError: If ``role`` is not a supported message role.
        """
        message_map = {
            "user": Message.user_message,
            "system": Message.system_message,
            "assistant": Message.assistant_message,
            "tool": lambda content, **kw: Message.tool_message(content, **kw),
        }

        if role not in message_map:
            raise ValueError(f"不支持的消息角色: {role}")

        if role == "tool":
            self.memory.add_message(message_map[role](content, **kwargs))
        else:
            self.memory.add_message(message_map[role](content, base64_image=base64_image))

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent's main loop.

        Args:
            request: Optional initial user request.

        Returns:
            A newline-joined summary of the executed steps.

        Raises:
            RuntimeError: If the agent is not in the IDLE state.
        """
        if self.state != AgentState.IDLE:
            raise RuntimeError(f"无法从状态 {self.state} 启动 Agent")

        if request:
            self.update_memory("user", request)

        results: List[str] = []
        async with self.state_context(AgentState.RUNNING):
            while (
                self.current_step < self.max_steps
                and self.state != AgentState.FINISHED
            ):
                self.current_step += 1
                logger.info(f"执行步骤 {self.current_step}/{self.max_steps}")
                step_result = await self.step()

                # Nudge the model when it keeps repeating itself.
                if self.is_stuck():
                    self.handle_stuck_state()

                results.append(f"Step {self.current_step}: {step_result}")

            if self.current_step >= self.max_steps:
                # Reset so the agent can be reused after hitting the budget.
                self.current_step = 0
                self.state = AgentState.IDLE
                results.append(f"已终止: 达到最大步数 ({self.max_steps})")

        return "\n".join(results) if results else "未执行任何步骤"

    @abstractmethod
    async def step(self) -> str:
        """Execute a single step -- must be implemented by subclasses."""

    def handle_stuck_state(self):
        """Prepend an anti-repetition hint to the next-step prompt."""
        stuck_prompt = "检测到重复响应。请考虑新策略,避免重复已尝试过的无效路径。"
        self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt or ''}"
        # Plain string: the previous f-string had no placeholders.
        logger.warning("Agent 检测到卡住状态,已添加提示")

    def is_stuck(self) -> bool:
        """Return True when the latest reply duplicates earlier assistant replies."""
        if len(self.memory.messages) < 2:
            return False

        last_message = self.memory.messages[-1]
        if not last_message.content:
            return False

        # Count earlier assistant messages with identical content.
        duplicate_count = sum(
            1
            for msg in reversed(self.memory.messages[:-1])
            if msg.role == "assistant" and msg.content == last_message.content
        )

        return duplicate_count >= self.duplicate_threshold

    @property
    def messages(self) -> List[Message]:
        """Messages currently stored in the agent's memory."""
        return self.memory.messages

    @messages.setter
    def messages(self, value: List[Message]):
        """Replace the messages stored in the agent's memory."""
        self.memory.messages = value
|
||||
361
MeetSpot/app/agent/meetspot_agent.py
Normal file
361
MeetSpot/app/agent/meetspot_agent.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""MeetSpotAgent - 智能会面地点推荐 Agent
|
||||
|
||||
基于 ReAct 模式实现的智能推荐代理,通过工具调用完成地点推荐任务。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.agent.base import BaseAgent
|
||||
from app.agent.tools import (
|
||||
CalculateCenterTool,
|
||||
GeocodeTool,
|
||||
GenerateRecommendationTool,
|
||||
SearchPOITool,
|
||||
)
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Message
|
||||
from app.tool.tool_collection import ToolCollection
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """你是 MeetSpot 智能会面助手,帮助用户找到最佳会面地点。
|
||||
|
||||
## 你的能力
|
||||
你可以使用以下工具来完成任务:
|
||||
|
||||
1. **geocode** - 地理编码
|
||||
- 将地址转换为经纬度坐标
|
||||
- 支持大学简称(北大、清华)、地标、商圈等
|
||||
- 返回坐标和格式化地址
|
||||
|
||||
2. **calculate_center** - 计算中心点
|
||||
- 计算多个位置的几何中心
|
||||
- 作为最佳会面位置的参考点
|
||||
- 使用球面几何确保精确
|
||||
|
||||
3. **search_poi** - 搜索场所
|
||||
- 在中心点附近搜索各类场所
|
||||
- 支持咖啡馆、餐厅、图书馆、健身房等
|
||||
- 返回名称、地址、评分、距离等
|
||||
|
||||
4. **generate_recommendation** - 生成推荐
|
||||
- 分析搜索结果
|
||||
- 根据评分、距离、用户需求排序
|
||||
- 生成个性化推荐理由
|
||||
|
||||
## 工作流程
|
||||
请按以下步骤执行:
|
||||
|
||||
1. **理解任务** - 分析用户提供的位置和需求
|
||||
2. **地理编码** - 依次对每个地址使用 geocode 获取坐标
|
||||
3. **计算中心** - 使用 calculate_center 计算最佳会面点
|
||||
4. **搜索场所** - 使用 search_poi 在中心点附近搜索
|
||||
5. **生成推荐** - 使用 generate_recommendation 生成最终推荐
|
||||
|
||||
## 输出要求
|
||||
- 推荐 3-5 个最佳场所
|
||||
- 为每个场所说明推荐理由(距离、评分、特色)
|
||||
- 考虑用户的特殊需求(停车、安静、商务等)
|
||||
- 使用中文回复
|
||||
|
||||
## 注意事项
|
||||
- 确保在调用工具前已获取所有必要参数
|
||||
- 如果地址解析失败,提供具体的错误信息和建议
|
||||
- 如果搜索无结果,尝试调整搜索关键词或扩大半径
|
||||
"""
|
||||
|
||||
|
||||
class MeetSpotAgent(BaseAgent):
    """MeetSpot meeting-point recommendation agent.

    A ReAct-style agent: each step runs a think() -> act() cycle, calling
    tools until the task is finished or the step budget runs out.
    """

    name: str = "MeetSpotAgent"
    description: str = "智能会面地点推荐助手"

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = "请继续执行下一步,或者如果已完成所有工具调用,请生成最终推荐结果。"

    max_steps: int = 15  # allow more steps for complex tasks

    # Tool collection (built in __init__ when not injected)
    available_tools: ToolCollection = Field(default=None)

    # Tool calls produced by the latest think() pass
    tool_calls: List[Any] = Field(default_factory=list)

    # Intermediate results accumulated during a run
    geocode_results: List[Dict] = Field(default_factory=list)
    center_point: Optional[Dict] = None
    search_results: List[Dict] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"

    def __init__(self, **data):
        super().__init__(**data)
        # Build the default tool set only when the caller did not inject one.
        if self.available_tools is None:
            self.available_tools = ToolCollection(
                GeocodeTool(),
                CalculateCenterTool(),
                SearchPOITool(),
                GenerateRecommendationTool()
            )

    async def step(self) -> str:
        """Run one think + act cycle.

        Returns:
            A description of what the step did.
        """
        # Think: decide the next action
        should_continue = await self.think()

        if not should_continue:
            self.state = AgentState.FINISHED
            return "任务完成"

        # Act: execute the chosen tool calls
        result = await self.act()
        return result

    async def think(self) -> bool:
        """Thinking phase -- decide the next action.

        Uses the LLM to analyse the current state and pick tool calls.

        Returns:
            Whether execution should continue.
        """
        # Build the message list from memory
        messages = self.memory.messages.copy()

        # After the first step, append guidance for the next one
        if self.next_step_prompt and self.current_step > 1:
            messages.append(Message.user_message(self.next_step_prompt))

        # Ask the LLM for the next move (may include tool calls)
        response = await self.llm.ask_tool(
            messages=messages,
            system_msgs=[Message.system_message(self.system_prompt)],
            tools=self.available_tools.to_params(),
            tool_choice="auto"
        )

        if response is None:
            logger.warning("LLM 返回空响应")
            return False

        # Extract tool calls and text content
        self.tool_calls = response.tool_calls or []
        content = response.content or ""

        logger.info(f"Agent 思考: {content[:200]}..." if len(content) > 200 else f"Agent 思考: {content}")

        if self.tool_calls:
            tool_names = [tc.function.name for tc in self.tool_calls]
            logger.info(f"选择工具: {tool_names}")

        # Persist the assistant message to memory
        if self.tool_calls:
            # Message carrying tool calls
            tool_calls_data = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                }
                for tc in self.tool_calls
            ]
            self.memory.add_message(Message(
                role="assistant",
                content=content,
                tool_calls=tool_calls_data
            ))
        elif content:
            # Plain-text message (possibly the final answer)
            self.memory.add_message(Message.assistant_message(content))
            # No tool calls plus substantial content mentioning a
            # recommendation: treat it as the final reply.
            if "推荐" in content and len(content) > 100:
                return False  # stop the loop

        return bool(self.tool_calls) or bool(content)

    async def act(self) -> str:
        """Acting phase -- execute the pending tool calls.

        Executes the tool calls chosen during think() and records each
        result (or error) in memory.

        Returns:
            A "|"-joined summary of per-tool outcomes.
        """
        if not self.tool_calls:
            # Nothing to execute; surface the last message instead
            return self.memory.messages[-1].content or "无操作"

        results = []
        for call in self.tool_calls:
            tool_name = call.function.name
            tool_args = call.function.arguments

            try:
                # Arguments may arrive as a JSON string or a dict
                args = json.loads(tool_args) if isinstance(tool_args, str) else tool_args

                # Execute the tool
                logger.info(f"执行工具: {tool_name}, 参数: {args}")
                result = await self.available_tools.execute(name=tool_name, tool_input=args)

                # Capture intermediate results for the final payload
                self._save_intermediate_result(tool_name, result, args)

                # Record the tool result in memory
                result_str = str(result)
                self.memory.add_message(Message.tool_message(
                    content=result_str,
                    tool_call_id=call.id,
                    name=tool_name
                ))

                logger.info(f"工具 {tool_name} 完成")
                results.append(f"{tool_name}: 成功")

            except Exception as e:
                error_msg = f"工具执行失败: {str(e)}"
                logger.error(f"{tool_name} {error_msg}")

                # Record the error in memory so the LLM can react to it
                self.memory.add_message(Message.tool_message(
                    content=error_msg,
                    tool_call_id=call.id,
                    name=tool_name
                ))
                results.append(f"{tool_name}: 失败 - {str(e)}")

        return " | ".join(results)

    def _save_intermediate_result(self, tool_name: str, result: Any, args: Dict) -> None:
        """Store intermediate tool outputs on the agent.

        Args:
            tool_name: Name of the executed tool.
            result: Tool execution result (expected to expose ``.output``).
            args: Arguments the tool was invoked with.
        """
        try:
            # Parse the result payload; tools return JSON in `.output`
            if hasattr(result, 'output') and result.output:
                data = json.loads(result.output) if isinstance(result.output, str) else result.output
            else:
                return

            if tool_name == "geocode" and data:
                self.geocode_results.append({
                    "address": args.get("address", ""),
                    "lng": data.get("lng"),
                    "lat": data.get("lat"),
                    "formatted_address": data.get("formatted_address", "")
                })

            elif tool_name == "calculate_center" and data:
                self.center_point = data.get("center")

            elif tool_name == "search_poi" and data:
                places = data.get("places", [])
                self.search_results.extend(places)

        except Exception as e:
            # Best-effort bookkeeping; never fail the run over it
            logger.debug(f"保存中间结果时出错: {e}")

    async def recommend(
        self,
        locations: List[str],
        keywords: str = "咖啡馆",
        requirements: str = ""
    ) -> Dict:
        """Run the recommendation task.

        Main entry point: takes user input and returns recommendations.

        Args:
            locations: Participant location names/addresses.
            keywords: Venue type to search for.
            requirements: Special user requirements.

        Returns:
            A dict with the recommendation and intermediate data.
        """
        # Reset per-run state
        self.geocode_results = []
        self.center_point = None
        self.search_results = []
        self.current_step = 0
        self.state = AgentState.IDLE
        self.memory.clear()

        # Build the task description for the LLM
        locations_str = "、".join(locations)
        task = f"""请帮我找到适合会面的地点:

**参与者位置**:{locations_str}
**想找的场所类型**:{keywords}
**特殊需求**:{requirements or "无特殊需求"}

请按照工作流程执行:
1. 先用 geocode 工具获取每个位置的坐标
2. 用 calculate_center 计算中心点
3. 用 search_poi 搜索附近的 {keywords}
4. 用 generate_recommendation 生成推荐

最后请用中文总结推荐结果。"""

        # Execute the agent loop
        result = await self.run(task)

        # Shape the final response
        return self._format_result(result)

    def _format_result(self, raw_result: str) -> Dict:
        """Format the final result.

        Args:
            raw_result: Raw summary returned by the agent loop.

        Returns:
            A formatted result dict.
        """
        # Use the last non-empty assistant message as the recommendation
        final_recommendation = ""
        for msg in reversed(self.memory.messages):
            if msg.role == "assistant" and msg.content:
                final_recommendation = msg.content
                break

        return {
            "success": self.state == AgentState.IDLE,  # IDLE means a clean finish
            "recommendation": final_recommendation,
            "geocode_results": self.geocode_results,
            "center_point": self.center_point,
            "search_results": self.search_results[:10],  # cap the payload size
            "steps_executed": self.current_step,
            "raw_output": raw_result
        }
|
||||
|
||||
|
||||
# Factory for a default-configured agent instance.
def create_meetspot_agent() -> MeetSpotAgent:
    """Build and return a ready-to-use ``MeetSpotAgent``.

    Returns:
        A ``MeetSpotAgent`` with its default tool collection.
    """
    agent = MeetSpotAgent()
    return agent
|
||||
514
MeetSpot/app/agent/tools.py
Normal file
514
MeetSpot/app/agent/tools.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""MeetSpot Agent 工具集 - 封装推荐系统的核心功能"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.tool.base import BaseTool, ToolResult
|
||||
from app.logger import logger
|
||||
|
||||
|
||||
class GeocodeTool(BaseTool):
    """Geocoding tool -- convert an address into lng/lat coordinates."""

    name: str = "geocode"
    description: str = """将地址或地点名称转换为经纬度坐标。
支持各种地址格式:
- 完整地址:'北京市海淀区中关村大街1号'
- 大学简称:'北大'、'清华'、'复旦'(自动扩展为完整地址)
- 知名地标:'天安门'、'外滩'、'广州塔'
- 商圈区域:'三里屯'、'王府井'

返回地址的经纬度坐标和格式化地址。"""
    # JSON-Schema style parameter spec consumed by the LLM tool-calling API
    parameters: dict = {
        "type": "object",
        "properties": {
            "address": {
                "type": "string",
                "description": "地址或地点名称,如'北京大学'、'上海市浦东新区陆家嘴'"
            }
        },
        "required": ["address"]
    }

    class Config:
        arbitrary_types_allowed = True

    def _get_recommender(self):
        """Lazily create and cache the recommender, wiring in the API key."""
        if not hasattr(self, '_cached_recommender'):
            from app.tool.meetspot_recommender import CafeRecommender
            from app.config import config
            recommender = CafeRecommender()
            # Propagate the AMap API key from app config when available
            if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
                recommender.api_key = config.amap.api_key
            # object.__setattr__ sidesteps pydantic's attribute validation
            object.__setattr__(self, '_cached_recommender', recommender)
        return self._cached_recommender

    async def execute(self, address: str) -> ToolResult:
        """Geocode *address* and return its coordinates and metadata."""
        try:
            recommender = self._get_recommender()
            result = await recommender._geocode(address)

            if result:
                location = result.get("location", "")
                # NOTE(review): assumes location is exactly "lng,lat"; a
                # malformed value would raise ValueError (caught below).
                lng, lat = location.split(",") if location else (None, None)

                return BaseTool.success_response({
                    "address": address,
                    "formatted_address": result.get("formatted_address", ""),
                    "location": location,
                    "lng": float(lng) if lng else None,
                    "lat": float(lat) if lat else None,
                    "city": result.get("city", ""),
                    "district": result.get("district", "")
                })

            return BaseTool.fail_response(f"无法解析地址: {address}")

        except Exception as e:
            logger.error(f"地理编码失败: {e}")
            return BaseTool.fail_response(f"地理编码错误: {str(e)}")
|
||||
|
||||
|
||||
class CalculateCenterTool(BaseTool):
    """Smart center tool -- compute the best meeting point for many locations.

    Uses a scoring algorithm that weighs:
    - POI density: enough target venues nearby
    - Transit access: proximity to metro/bus stops
    - Fairness: balanced distances for all participants
    """

    name: str = "calculate_center"
    description: str = """智能计算最佳会面中心点。

不同于简单的几何中心,本工具会:
1. 在几何中心周围生成多个候选点
2. 评估每个候选点的 POI 密度、交通便利性和公平性
3. 返回综合得分最高的点作为最佳会面位置

这样可以避免中心点落在河流、荒地等不适合的位置。"""
    # JSON-Schema style parameter spec consumed by the LLM tool-calling API
    parameters: dict = {
        "type": "object",
        "properties": {
            "coordinates": {
                "type": "array",
                "description": "坐标点列表,每个元素包含 lng(经度)、lat(纬度)和可选的 name(名称)",
                "items": {
                    "type": "object",
                    "properties": {
                        "lng": {"type": "number", "description": "经度"},
                        "lat": {"type": "number", "description": "纬度"},
                        "name": {"type": "string", "description": "位置名称(可选)"}
                    },
                    "required": ["lng", "lat"]
                }
            },
            "keywords": {
                "type": "string",
                "description": "搜索的场所类型,如'咖啡馆'、'餐厅',用于评估 POI 密度",
                "default": "咖啡馆"
            },
            "use_smart_algorithm": {
                "type": "boolean",
                "description": "是否使用智能算法(考虑 POI 密度和交通),默认 true",
                "default": True
            }
        },
        "required": ["coordinates"]
    }

    class Config:
        arbitrary_types_allowed = True

    def _get_recommender(self):
        """Lazily create and cache the recommender, wiring in the API key."""
        if not hasattr(self, '_cached_recommender'):
            from app.tool.meetspot_recommender import CafeRecommender
            from app.config import config
            recommender = CafeRecommender()
            if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
                recommender.api_key = config.amap.api_key
            object.__setattr__(self, '_cached_recommender', recommender)
        return self._cached_recommender

    async def execute(
        self,
        coordinates: List[Dict],
        keywords: str = "咖啡馆",
        use_smart_algorithm: bool = True
    ) -> ToolResult:
        """Compute the best center point for the given coordinates."""
        try:
            if not coordinates or len(coordinates) < 2:
                return BaseTool.fail_response("至少需要2个坐标点来计算中心")

            recommender = self._get_recommender()

            # Convert to a list of (lng, lat) tuples
            coord_tuples = [(c["lng"], c["lat"]) for c in coordinates]

            if use_smart_algorithm:
                # Smart center-point algorithm (candidate scoring)
                center, evaluation_details = await recommender._calculate_smart_center(
                    coord_tuples, keywords
                )
                logger.info(f"智能中心点算法完成,最优中心: {center}")
            else:
                # Plain geometric center
                center = recommender._calculate_center_point(coord_tuples)
                evaluation_details = {"algorithm": "geometric_center"}

            # Distance from every input point to the chosen center
            distances = []
            for c in coordinates:
                dist = recommender._calculate_distance(center, (c["lng"], c["lat"]))
                distances.append({
                    "name": c.get("name", f"({c['lng']:.4f}, {c['lat']:.4f})"),
                    "distance_to_center": round(dist, 0)
                })

            max_dist = max(d["distance_to_center"] for d in distances)
            min_dist = min(d["distance_to_center"] for d in distances)

            result = {
                "center": {
                    "lng": round(center[0], 6),
                    "lat": round(center[1], 6)
                },
                "algorithm": "smart" if use_smart_algorithm else "geometric",
                "input_count": len(coordinates),
                "distances": distances,
                "max_distance": max_dist,
                # NOTE(review): can go negative when distances differ by
                # more than 10km -- confirm intended scale.
                "fairness_score": round(100 - (max_dist - min_dist) / 100, 1)
            }

            # Attach evaluation details produced by the smart algorithm
            if use_smart_algorithm and evaluation_details:
                result["evaluation"] = {
                    "geo_center": evaluation_details.get("geo_center"),
                    "best_score": evaluation_details.get("best_score"),
                    "top_candidates": len(evaluation_details.get("all_candidates", []))
                }

            return BaseTool.success_response(result)

        except Exception as e:
            logger.error(f"计算中心点失败: {e}")
            return BaseTool.fail_response(f"计算中心点错误: {str(e)}")
|
||||
|
||||
|
||||
class SearchPOITool(BaseTool):
    """POI search tool -- search venues around a given location."""

    name: str = "search_poi"
    description: str = """在指定中心点周围搜索各类场所(POI)。
支持搜索:咖啡馆、餐厅、图书馆、健身房、KTV、电影院、商场等。
返回场所的名称、地址、评分、距离等信息。"""
    # JSON-Schema style parameter spec consumed by the LLM tool-calling API
    parameters: dict = {
        "type": "object",
        "properties": {
            "center_lng": {
                "type": "number",
                "description": "中心点经度"
            },
            "center_lat": {
                "type": "number",
                "description": "中心点纬度"
            },
            "keywords": {
                "type": "string",
                "description": "搜索关键词,如'咖啡馆'、'餐厅'、'图书馆'"
            },
            "radius": {
                "type": "integer",
                "description": "搜索半径(米),默认3000米",
                "default": 3000
            }
        },
        "required": ["center_lng", "center_lat", "keywords"]
    }

    class Config:
        arbitrary_types_allowed = True

    def _get_recommender(self):
        """Lazily create and cache the recommender, wiring in the API key."""
        if not hasattr(self, '_cached_recommender'):
            from app.tool.meetspot_recommender import CafeRecommender
            from app.config import config
            recommender = CafeRecommender()
            if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
                recommender.api_key = config.amap.api_key
            object.__setattr__(self, '_cached_recommender', recommender)
        return self._cached_recommender

    async def execute(
        self,
        center_lng: float,
        center_lat: float,
        keywords: str,
        radius: int = 3000
    ) -> ToolResult:
        """Search POIs around the given center point."""
        try:
            recommender = self._get_recommender()
            center = f"{center_lng},{center_lat}"

            places = await recommender._search_pois(
                location=center,
                keywords=keywords,
                radius=radius,
                types="",
                offset=20
            )

            if not places:
                return BaseTool.fail_response(
                    f"在 ({center_lng:.4f}, {center_lat:.4f}) 附近 {radius}米范围内"
                    f"未找到与 '{keywords}' 相关的场所"
                )

            # Trim the raw POI payload to the fields the agent needs
            simplified = []
            for p in places[:15]:  # return at most 15 places
                biz_ext = p.get("biz_ext", {}) or {}
                location = p.get("location", "")
                # NOTE(review): assumes location is "lng,lat"; malformed
                # values would raise here (caught by the outer except).
                lng, lat = location.split(",") if location else (0, 0)

                # Distance from the place to the requested center
                distance = recommender._calculate_distance(
                    (center_lng, center_lat),
                    (float(lng), float(lat))
                ) if location else 0

                simplified.append({
                    "name": p.get("name", ""),
                    "address": p.get("address", ""),
                    "rating": biz_ext.get("rating", "N/A"),
                    "cost": biz_ext.get("cost", ""),
                    "location": location,
                    "lng": float(lng) if lng else None,
                    "lat": float(lat) if lat else None,
                    "distance": round(distance, 0),
                    "tel": p.get("tel", ""),
                    "tag": p.get("tag", ""),
                    "type": p.get("type", "")
                })

            # Nearest first
            simplified.sort(key=lambda x: x.get("distance", 9999))

            return BaseTool.success_response({
                "places": simplified,
                "count": len(simplified),
                "keywords": keywords,
                "center": {"lng": center_lng, "lat": center_lat},
                "radius": radius
            })

        except Exception as e:
            logger.error(f"POI搜索失败: {e}")
            return BaseTool.fail_response(f"POI搜索错误: {str(e)}")
|
||||
|
||||
|
||||
class GenerateRecommendationTool(BaseTool):
    """Smart recommendation tool -- rank venues with an LLM-assisted score.

    Combines rule-based and LLM scoring:
    - Rule score: objective metrics such as distance, rating, popularity
    - LLM score: semantic match between venue traits and user needs
    """

    name: str = "generate_recommendation"
    description: str = """智能生成会面地点推荐。

本工具使用双层评分系统:
1. 规则评分(40%):基于距离、评分、热度等客观指标
2. LLM 智能评分(60%):理解用户需求,评估场所特色与需求的匹配度

最终生成个性化的推荐理由,帮助用户做出最佳选择。"""
    # JSON-Schema style parameter spec consumed by the LLM tool-calling API
    parameters: dict = {
        "type": "object",
        "properties": {
            "places": {
                "type": "array",
                "description": "候选场所列表(来自search_poi的结果)",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "场所名称"},
                        "address": {"type": "string", "description": "地址"},
                        "rating": {"type": "string", "description": "评分"},
                        "distance": {"type": "number", "description": "距中心点距离"},
                        "location": {"type": "string", "description": "坐标"}
                    }
                }
            },
            "center": {
                "type": "object",
                "description": "中心点坐标",
                "properties": {
                    "lng": {"type": "number", "description": "经度"},
                    "lat": {"type": "number", "description": "纬度"}
                },
                "required": ["lng", "lat"]
            },
            "participant_locations": {
                "type": "array",
                "description": "参与者位置名称列表,用于 LLM 评估公平性",
                "items": {"type": "string"},
                "default": []
            },
            "keywords": {
                "type": "string",
                "description": "搜索的场所类型,如'咖啡馆'、'餐厅'",
                "default": "咖啡馆"
            },
            "user_requirements": {
                "type": "string",
                "description": "用户的特殊需求,如'停车方便'、'环境安静'",
                "default": ""
            },
            "recommendation_count": {
                "type": "integer",
                "description": "推荐数量,默认5个",
                "default": 5
            },
            "use_llm_ranking": {
                "type": "boolean",
                "description": "是否使用 LLM 智能排序,默认 true",
                "default": True
            }
        },
        "required": ["places", "center"]
    }

    class Config:
        arbitrary_types_allowed = True

    def _get_recommender(self):
        """Lazily create and cache the recommender, wiring in the API key."""
        if not hasattr(self, '_cached_recommender'):
            from app.tool.meetspot_recommender import CafeRecommender
            from app.config import config
            recommender = CafeRecommender()
            if hasattr(config, 'amap') and config.amap and hasattr(config.amap, 'api_key'):
                recommender.api_key = config.amap.api_key
            object.__setattr__(self, '_cached_recommender', recommender)
        return self._cached_recommender

    async def execute(
        self,
        places: List[Dict],
        center: Dict,
        participant_locations: List[str] = None,
        keywords: str = "咖啡馆",
        user_requirements: str = "",
        recommendation_count: int = 5,
        use_llm_ranking: bool = True
    ) -> ToolResult:
        """Generate ranked recommendations from candidate places."""
        try:
            if not places:
                return BaseTool.fail_response("没有候选场所可供推荐")

            recommender = self._get_recommender()
            center_point = (center["lng"], center["lat"])

            # 1. Initial ordering via rule-based scoring
            ranked = recommender._rank_places(
                places=places,
                center_point=center_point,
                user_requirements=user_requirements,
                keywords=keywords
            )

            # 2. Optional LLM re-ranking (needs participant locations)
            if use_llm_ranking and participant_locations:
                logger.info("启用 LLM 智能排序")
                ranked = await recommender._llm_smart_ranking(
                    places=ranked,
                    user_requirements=user_requirements,
                    participant_locations=participant_locations or [],
                    keywords=keywords,
                    top_n=recommendation_count + 3  # fetch extras for filtering
                )

            # Take the top N recommendations
            top_places = ranked[:recommendation_count]

            # Build the recommendation payload
            recommendations = []
            for i, place in enumerate(top_places, 1):
                score = place.get("_final_score") or place.get("_score", 0)
                distance = place.get("_distance") or place.get("distance", 0)
                rating = place.get("_raw_rating") or place.get("rating", "N/A")

                # Prefer the LLM-generated reason, then the rule-based one
                llm_reason = place.get("_llm_reason", "")
                rule_reason = place.get("_recommendation_reason", "")

                if llm_reason:
                    reasons = [llm_reason]
                elif rule_reason:
                    reasons = [rule_reason]
                else:
                    # Fallback: assemble a basic reason from distance/rating
                    reasons = []
                    if distance <= 500:
                        reasons.append("距离中心点很近")
                    elif distance <= 1000:
                        reasons.append("距离适中")

                    if rating != "N/A":
                        try:
                            r = float(rating)
                            if r >= 4.5:
                                reasons.append("口碑优秀")
                            elif r >= 4.0:
                                reasons.append("评价良好")
                        except (ValueError, TypeError):
                            pass

                    if not reasons:
                        reasons = ["综合评分较高"]

                recommendations.append({
                    "rank": i,
                    "name": place.get("name", ""),
                    "address": place.get("address", ""),
                    "rating": str(rating) if rating else "N/A",
                    "distance": round(distance, 0),
                    "score": round(score, 1),
                    "llm_score": place.get("_llm_score", 0),
                    "tel": place.get("tel", ""),
                    "reasons": reasons,
                    "location": place.get("location", ""),
                    "scoring_method": "llm+rule" if place.get("_llm_score") else "rule"
                })

            return BaseTool.success_response({
                "recommendations": recommendations,
                "total_candidates": len(places),
                "user_requirements": user_requirements,
                "center": center,
                "llm_ranking_used": use_llm_ranking and bool(participant_locations)
            })

        except Exception as e:
            logger.error(f"生成推荐失败: {e}")
            return BaseTool.fail_response(f"生成推荐错误: {str(e)}")
|
||||
|
||||
|
||||
# Export all tools -- public API of this module.
__all__ = [
    "GeocodeTool",
    "CalculateCenterTool",
    "SearchPOITool",
    "GenerateRecommendationTool"
]
|
||||
2
MeetSpot/app/auth/__init__.py
Normal file
2
MeetSpot/app/auth/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""认证相关模块。"""
|
||||
|
||||
58
MeetSpot/app/auth/jwt.py
Normal file
58
MeetSpot/app/auth/jwt.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""JWT 工具函数与统一用户依赖。"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.crud import get_user_by_id
|
||||
from app.db.database import get_db
|
||||
|
||||
|
||||
# Signing configuration for access tokens.
# NOTE(review): the fallback secret is a well-known development value —
# production deployments must set JWT_SECRET_KEY in the environment.
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "meetspot-dev-secret")
ALGORITHM = "HS256"
# Token lifetime in days (default: one week), overridable via environment.
ACCESS_TOKEN_EXPIRE_DAYS = int(os.getenv("ACCESS_TOKEN_EXPIRE_DAYS", "7"))

# auto_error=False lets get_current_user raise its own 401 with a
# localized message instead of FastAPI's default error.
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
    """Create a signed JWT with an expiry claim.

    Args:
        data: Claims to embed in the token (e.g. ``{"sub": user_id}``).
              The input dict is not mutated.

    Returns:
        The encoded JWT string signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # datetime.utcnow() is deprecated since Python 3.12 and yields a naive
    # datetime; use an explicit timezone-aware UTC timestamp instead.
    # python-jose serializes aware datetimes to the same numeric "exp" claim.
    from datetime import timezone

    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its claims, or ``None`` if invalid/expired."""
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
):
    """FastAPI dependency resolving the authenticated user.

    Raises HTTP 401 when the bearer token is missing, cannot be decoded,
    lacks a ``sub`` claim, or does not map to an existing user.
    """
    def _unauthorized(detail: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)

    if not credentials:
        raise _unauthorized("缺少认证信息")

    claims = decode_token(credentials.credentials)
    if not claims or "sub" not in claims:
        raise _unauthorized("令牌无效")

    user = await get_user_by_id(db, claims["sub"])
    if not user:
        raise _unauthorized("用户不存在")

    return user
|
||||
|
||||
24
MeetSpot/app/auth/sms.py
Normal file
24
MeetSpot/app/auth/sms.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""短信验证码(Mock版)。"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
# Fixed verification code used by the mock SMS flow.
MOCK_CODE = "123456"
# Remembers the last code issued per phone number (process-local only).
_code_store: Dict[str, str] = {}
|
||||
|
||||
|
||||
async def send_login_code(phone: str) -> str:
    """Mock "send SMS verification code": always issues ``123456``.

    - Replace with a real SMS-gateway call in production.
    - The last issued code is remembered per phone so a stricter
      validation can be layered on later.
    """
    issued = MOCK_CODE
    _code_store[phone] = issued
    return issued
|
||||
|
||||
|
||||
def validate_code(phone: str, code: str) -> bool:
    """Check a login verification code.

    Prefers the code actually issued to *phone* (tracked in ``_code_store``
    by :func:`send_login_code`) and falls back to the fixed mock value, so
    MVP behavior is unchanged (every issued code is ``MOCK_CODE``) while the
    per-phone store is no longer dead state.
    """
    expected = _code_store.get(phone, MOCK_CODE)
    return code == expected
|
||||
|
||||
315
MeetSpot/app/config.py
Normal file
315
MeetSpot/app/config.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import threading
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def get_project_root() -> Path:
    """Return the project root directory (two levels above this module)."""
    here = Path(__file__).resolve()
    return here.parent.parent
|
||||
|
||||
|
||||
# Resolved once at import time; all other paths derive from these.
PROJECT_ROOT = get_project_root()
WORKSPACE_ROOT = PROJECT_ROOT / "workspace"
|
||||
|
||||
|
||||
class LLMSettings(BaseModel):
    """Settings for one LLM endpoint (the config may define several profiles)."""

    model: str = Field(..., description="Model name")
    base_url: str = Field(..., description="API base URL")
    api_key: str = Field(..., description="API key")
    max_tokens: int = Field(4096, description="Maximum number of tokens per request")
    max_input_tokens: Optional[int] = Field(
        None,
        description="Maximum input tokens to use across all requests (None for unlimited)",
    )
    temperature: float = Field(1.0, description="Sampling temperature")
    api_type: str = Field(..., description="Azure, Openai, or Ollama")
    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
|
||||
|
||||
|
||||
class ProxySettings(BaseModel):
    """Optional HTTP proxy settings for the browser."""

    # ``server`` was annotated ``str`` while defaulting to ``None`` — an
    # invalid combination under strict validation; widened to Optional[str]
    # (backward-compatible: all existing values still validate).
    server: Optional[str] = Field(None, description="Proxy server address")
    username: Optional[str] = Field(None, description="Proxy username")
    password: Optional[str] = Field(None, description="Proxy password")
|
||||
|
||||
|
||||
class SearchSettings(BaseModel):
    """Web-search configuration."""

    engine: str = Field(default="Google", description="Search engine the llm to use")
|
||||
|
||||
|
||||
class AMapSettings(BaseModel):
    """高德地图API配置"""
    api_key: str = Field(..., description="高德地图API密钥")
    web_api_key: Optional[str] = Field(None, description="高德地图JavaScript API密钥")
    # Fix: Config._load_initial_config constructs this model with a
    # ``security_js_code`` keyword (from TOML or AMAP_SECURITY_JS_CODE);
    # without this field the value was silently discarded.
    security_js_code: Optional[str] = Field(
        None, description="高德地图JavaScript API安全密钥"
    )
|
||||
|
||||
|
||||
class BrowserSettings(BaseModel):
    """Browser automation configuration (local Chrome, WebSocket, or CDP)."""

    headless: bool = Field(False, description="Whether to run browser in headless mode")
    disable_security: bool = Field(
        True, description="Disable browser security features"
    )
    extra_chromium_args: List[str] = Field(
        default_factory=list, description="Extra arguments to pass to the browser"
    )
    chrome_instance_path: Optional[str] = Field(
        None, description="Path to a Chrome instance to use"
    )
    wss_url: Optional[str] = Field(
        None, description="Connect to a browser instance via WebSocket"
    )
    cdp_url: Optional[str] = Field(
        None, description="Connect to a browser instance via CDP"
    )
    proxy: Optional[ProxySettings] = Field(
        None, description="Proxy settings for the browser"
    )
    max_content_length: int = Field(
        2000, description="Maximum length for content retrieval operations"
    )
|
||||
|
||||
|
||||
class SandboxSettings(BaseModel):
    """Configuration for the execution sandbox"""

    # Defaults are conservative: sandbox disabled, network disabled.
    use_sandbox: bool = Field(False, description="Whether to use the sandbox")
    image: str = Field("python:3.12-slim", description="Base image")
    work_dir: str = Field("/workspace", description="Container working directory")
    memory_limit: str = Field("512m", description="Memory limit")
    cpu_limit: float = Field(1.0, description="CPU limit")
    timeout: int = Field(300, description="Default command timeout (seconds)")
    network_enabled: bool = Field(
        False, description="Whether network access is allowed"
    )
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
    """Top-level validated application configuration."""

    # Keyed by profile name; "default" is always present (built by Config).
    llm: Dict[str, LLMSettings]
    sandbox: Optional[SandboxSettings] = Field(
        None, description="Sandbox configuration"
    )
    browser_config: Optional[BrowserSettings] = Field(
        None, description="Browser configuration"
    )
    search_config: Optional[SearchSettings] = Field(
        None, description="Search configuration"
    )
    amap: Optional[AMapSettings] = Field(
        None, description="高德地图API配置"
    )

    class Config:
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Config:
    """Thread-safe singleton that loads, validates, and exposes app settings.

    Configuration is read from ``config/config.toml`` (falling back to the
    ``.example`` file, then to built-in defaults) and merged with
    environment variables for secrets and deployment overrides.
    """

    _instance = None              # the one shared instance
    _lock = threading.Lock()      # guards instance creation and first init
    _initialized = False          # set once the configuration has been loaded

    def __new__(cls):
        # Double-checked locking so only one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Config() call; load the file only once.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._config = None
                    self._load_initial_config()
                    self._initialized = True

    @staticmethod
    def _get_config_path() -> Path:
        """Pick the config file: config.toml, else the .example, else the default path."""
        root = PROJECT_ROOT
        config_path = root / "config" / "config.toml"
        if config_path.exists():
            return config_path
        example_path = root / "config" / "config.toml.example"
        if example_path.exists():
            return example_path
        # Neither file was found: return the default path so the caller
        # can fall back to the built-in defaults.
        return config_path

    def _load_config(self) -> dict:
        """Read the TOML config file, falling back to built-in defaults on any failure."""
        try:
            config_path = self._get_config_path()
            if not config_path.exists():
                # Build an in-memory default configuration.
                default_config = {
                    "llm": {
                        "model": "gpt-3.5-turbo",
                        "api_key": "",
                        "base_url": "",
                        "max_tokens": 4096,
                        "temperature": 1.0,
                        "api_type": "",
                        "api_version": ""
                    },
                    "amap": {
                        "api_key": "",
                        "security_js_code": ""
                    },
                    "log": {
                        "level": "info",
                        "file": "logs/meetspot.log"
                    },
                    "server": {
                        "host": "0.0.0.0",
                        "port": 8000
                    }
                }
                return default_config

            with config_path.open("rb") as f:
                return tomllib.load(f)
        except Exception as e:
            # Loading failed (bad TOML, I/O error, ...): warn and use defaults.
            print(f"Failed to load config file, using defaults: {e}")
            return {
                "llm": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "",
                    "base_url": "",
                    "max_tokens": 4096,
                    "temperature": 1.0,
                    "api_type": "",
                    "api_version": ""
                }
            }

    def _load_initial_config(self):
        """Merge the file config with environment variables and build the AppConfig."""
        raw_config = self._load_config()
        base_llm = raw_config.get("llm", {})

        # Sensitive values are read from the environment, not the file.
        import os
        openai_api_key = os.getenv("OPENAI_API_KEY", "") or os.getenv("LLM_API_KEY", "")
        amap_api_key = os.getenv("AMAP_API_KEY", "")
        # Environment-variable overrides used by Render (and similar) deployments.
        llm_base_url = os.getenv("LLM_API_BASE", "") or base_llm.get("base_url", "")
        llm_model = os.getenv("LLM_MODEL", "") or base_llm.get("model", "gpt-3.5-turbo")

        # Nested [llm.<name>] tables become per-profile overrides.
        llm_overrides = {
            k: v for k, v in raw_config.get("llm", {}).items() if isinstance(v, dict)
        }

        default_settings = {
            "model": llm_model,  # environment variable takes precedence
            "base_url": llm_base_url,  # environment variable takes precedence
            "api_key": openai_api_key,  # taken from the environment
            "max_tokens": base_llm.get("max_tokens", 4096),
            "max_input_tokens": base_llm.get("max_input_tokens"),
            "temperature": base_llm.get("temperature", 1.0),
            "api_type": base_llm.get("api_type", ""),
            "api_version": base_llm.get("api_version", ""),
        }

        # handle browser config.
        browser_config = raw_config.get("browser", {})
        browser_settings = None

        if browser_config:
            # handle proxy settings.
            proxy_config = browser_config.get("proxy", {})
            proxy_settings = None

            if proxy_config and proxy_config.get("server"):
                proxy_settings = ProxySettings(
                    **{
                        k: v
                        for k, v in proxy_config.items()
                        if k in ["server", "username", "password"] and v
                    }
                )

            # filter valid browser config parameters.
            valid_browser_params = {
                k: v
                for k, v in browser_config.items()
                if k in BrowserSettings.__annotations__ and v is not None
            }

            # if there is proxy settings, add it to the parameters.
            if proxy_settings:
                valid_browser_params["proxy"] = proxy_settings

            # only create BrowserSettings when there are valid parameters.
            if valid_browser_params:
                browser_settings = BrowserSettings(**valid_browser_params)

        search_config = raw_config.get("search", {})
        search_settings = None
        if search_config:
            search_settings = SearchSettings(**search_config)
        sandbox_config = raw_config.get("sandbox", {})
        if sandbox_config:
            sandbox_settings = SandboxSettings(**sandbox_config)
        else:
            sandbox_settings = SandboxSettings()

        # AMap API configuration.
        amap_config = raw_config.get("amap", {})
        amap_settings = None
        # An AMAP_API_KEY environment variable takes precedence over the file.
        if amap_api_key:
            amap_settings = AMapSettings(
                api_key=amap_api_key,
                security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", amap_config.get("security_js_code", ""))
            )
        elif amap_config and amap_config.get("api_key"):
            amap_settings = AMapSettings(**amap_config)

        config_dict = {
            "llm": {
                "default": default_settings,
                **{
                    name: {**default_settings, **override_config}
                    for name, override_config in llm_overrides.items()
                },
            },
            "sandbox": sandbox_settings,
            "browser_config": browser_settings,
            "search_config": search_settings,
            "amap": amap_settings,
        }

        self._config = AppConfig(**config_dict)

    @property
    def llm(self) -> Dict[str, LLMSettings]:
        """LLM settings per profile ("default" is always present)."""
        return self._config.llm

    @property
    def sandbox(self) -> SandboxSettings:
        """Execution sandbox settings."""
        return self._config.sandbox

    @property
    def browser_config(self) -> Optional[BrowserSettings]:
        """Browser automation settings, if configured."""
        return self._config.browser_config

    @property
    def search_config(self) -> Optional[SearchSettings]:
        """Web search settings, if configured."""
        return self._config.search_config

    @property
    def amap(self) -> Optional[AMapSettings]:
        """AMap API configuration, if configured."""
        return self._config.amap

    @property
    def workspace_root(self) -> Path:
        """Get the workspace root directory"""
        return WORKSPACE_ROOT

    @property
    def root_path(self) -> Path:
        """Get the root path of the application"""
        return PROJECT_ROOT
|
||||
|
||||
|
||||
# Module-level singleton; import ``config`` instead of instantiating Config.
config = Config()
|
||||
120
MeetSpot/app/config_simple.py
Normal file
120
MeetSpot/app/config_simple.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
import threading
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def get_project_root() -> Path:
    """Return the project root, i.e. two directories above this module."""
    return Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
# Computed once at import; WORKSPACE_ROOT is where generated pages live.
PROJECT_ROOT = get_project_root()
WORKSPACE_ROOT = PROJECT_ROOT / "workspace"
|
||||
|
||||
|
||||
class AMapSettings(BaseModel):
    """AMap (Gaode Maps) API configuration."""
    api_key: str = Field(..., description="高德地图API密钥")
    security_js_code: Optional[str] = Field(None, description="高德地图JavaScript API安全密钥")
|
||||
|
||||
|
||||
class LogSettings(BaseModel):
    """Logging configuration."""
    level: str = Field(default="INFO", description="日志级别")
    file_path: str = Field(default="logs/meetspot.log", description="日志文件路径")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
    """Validated application configuration (simplified deployment variant)."""

    amap: AMapSettings = Field(..., description="高德地图API配置")
    # Use default_factory so each model instance gets its own LogSettings
    # instead of sharing one object created at class-definition time.
    log: Optional[LogSettings] = Field(
        default_factory=LogSettings, description="日志配置"
    )

    class Config:
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Config:
    """Configuration manager (singleton).

    Load order: environment variables first (serverless deployments such as
    Vercel), then ``config/config.toml`` (local development), and finally an
    empty default so the application can still start.
    """
    _instance = None              # the one shared instance
    _lock = threading.Lock()      # guards creation, init, and reload
    _initialized = False          # set once configuration has been loaded

    def __new__(cls):
        # Double-checked locking: create the instance exactly once.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Runs on every Config() call; load the configuration only once.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._config = None
                    self._load_initial_config()
                    self._initialized = True

    def _load_initial_config(self):
        """Load configuration from env vars, the TOML file, or defaults (in that order)."""
        try:
            # 1) Environment variables first (Vercel deployment).
            if os.getenv("AMAP_API_KEY"):
                self._config = AppConfig(
                    amap=AMapSettings(
                        api_key=os.getenv("AMAP_API_KEY", ""),
                        security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", "")
                    )
                )
                return

            # 2) Fall back to the config file (local development).
            config_path = PROJECT_ROOT / "config" / "config.toml"
            if config_path.exists():
                with open(config_path, "rb") as f:
                    toml_data = tomllib.load(f)

                amap_config = toml_data.get("amap", {})
                if not amap_config.get("api_key"):
                    raise ValueError("高德地图API密钥未配置")

                self._config = AppConfig(
                    amap=AMapSettings(**amap_config),
                    log=LogSettings(**toml_data.get("log", {}))
                )
            else:
                raise FileNotFoundError(f"配置文件不存在: {config_path}")

        except Exception as e:
            # 3) Last resort: defaults so startup does not fail outright.
            print(f"配置加载失败,使用默认配置: {e}")
            self._config = AppConfig(
                amap=AMapSettings(
                    api_key=os.getenv("AMAP_API_KEY", ""),
                    security_js_code=os.getenv("AMAP_SECURITY_JS_CODE", "")
                )
            )

    def reload(self):
        """Re-load the configuration (e.g. after environment changes)."""
        with self._lock:
            self._initialized = False
            self._load_initial_config()
            self._initialized = True

    @property
    def amap(self) -> AMapSettings:
        """AMap API configuration."""
        return self._config.amap

    @property
    def log(self) -> LogSettings:
        """Logging configuration."""
        return self._config.log
|
||||
|
||||
|
||||
# Global configuration instance; import this rather than Config itself.
config = Config()
|
||||
2
MeetSpot/app/db/__init__.py
Normal file
2
MeetSpot/app/db/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""数据库相关模块初始化。"""
|
||||
|
||||
50
MeetSpot/app/db/crud.py
Normal file
50
MeetSpot/app/db/crud.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""常用数据库操作封装。"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def _default_nickname(phone: str) -> str:
|
||||
suffix = phone[-4:] if len(phone) >= 4 else phone
|
||||
return f"用户{suffix}"
|
||||
|
||||
|
||||
async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]:
    """Look up a user by phone number; returns ``None`` when absent."""
    query = select(User).where(User.phone == phone)
    return (await db.execute(query)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]:
    """Look up a user by primary key; returns ``None`` when absent."""
    query = select(User).where(User.id == user_id)
    return (await db.execute(query)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_user(
    db: AsyncSession, phone: str, nickname: Optional[str] = None, avatar_url: str = ""
) -> User:
    """Persist a new user row and return it (refreshed with DB defaults).

    When *nickname* is empty a default one is derived from the phone number.
    """
    display_name = nickname if nickname else _default_nickname(phone)
    new_user = User(
        phone=phone,
        nickname=display_name,
        avatar_url=avatar_url if avatar_url else "",
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
async def touch_last_login(db: AsyncSession, user: User) -> None:
    """Record the current time as the user's last login and commit.

    NOTE(review): ``datetime.utcnow()`` is deprecated and returns a naive
    datetime; switching to ``datetime.now(timezone.utc)`` would require the
    ``last_login`` column to accept timezone-aware values — confirm before
    changing.
    """
    user.last_login = datetime.utcnow()
    await db.commit()
|
||||
|
||||
48
MeetSpot/app/db/database.py
Normal file
48
MeetSpot/app/db/database.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""数据库引擎与会话管理。
|
||||
|
||||
使用SQLite作为MVP默认存储,保留通过环境变量`DATABASE_URL`切换到PostgreSQL的能力。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
|
||||
# Project root; the default SQLite database lives under <root>/data.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
DATA_DIR = PROJECT_ROOT / "data"
# NOTE(review): mkdir without parents=True assumes PROJECT_ROOT exists.
DATA_DIR.mkdir(exist_ok=True)

# The connection string can be overridden via the DATABASE_URL env var
# (e.g. to switch to PostgreSQL); defaults to a local aiosqlite file.
DEFAULT_SQLITE_PATH = DATA_DIR / "meetspot.db"
DATABASE_URL = os.getenv(
    "DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_SQLITE_PATH.as_posix()}"
)

# Async engine and session factory shared by the whole application.
engine = create_async_engine(DATABASE_URL, echo=False, future=True)
AsyncSessionLocal = async_sessionmaker(
    bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False
)

# Single declarative base that all ORM models inherit from.
Base = declarative_base()
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency: yield a database session and ensure it is closed."""
    async with AsyncSessionLocal() as session:
        yield session
|
||||
|
||||
|
||||
async def init_db() -> None:
    """Create all database tables at application startup."""
    # Imported lazily to avoid a circular import; importing app.models
    # registers every model class on Base.metadata.
    from app import models  # noqa: F401  ensure all models are registered

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
652
MeetSpot/app/design_tokens.py
Normal file
652
MeetSpot/app/design_tokens.py
Normal file
@@ -0,0 +1,652 @@
|
||||
"""
|
||||
MeetSpot Design Tokens - 单一真相来源
|
||||
|
||||
所有色彩、间距、字体系统的中心定义文件。
|
||||
修改本文件会影响:
|
||||
1. 基础模板 (templates/base.html)
|
||||
2. 静态HTML (public/*.html)
|
||||
3. 动态生成页面 (workspace/js_src/*.html)
|
||||
|
||||
WCAG 2.1 AA级对比度标准:
|
||||
- 正文: ≥ 4.5:1
|
||||
- 大文字: ≥ 3.0:1
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class DesignTokens:
|
||||
"""设计Token中心管理类"""
|
||||
|
||||
# ============================================================================
|
||||
# 全局品牌色 (Global Brand Colors) - MeetSpot 旅程主题
|
||||
# 这些颜色应用于所有页面的共同元素 (Header, Footer, 主按钮)
|
||||
# 配色理念:深海蓝(旅程与探索)+ 日落橙(会面的温暖)+ 薄荷绿(公平与平衡)
|
||||
# ============================================================================
|
||||
BRAND = {
|
||||
"primary": "#0A4D68", # 主色:深海蓝 - 沉稳、可信赖 (对比度: 9.12:1 ✓)
|
||||
"primary_dark": "#05445E", # 暗深海蓝 - 悬停态 (对比度: 11.83:1 ✓)
|
||||
"primary_light": "#088395", # 亮海蓝 - 装饰性元素 (对比度: 5.24:1 ✓)
|
||||
"gradient": "linear-gradient(135deg, #05445E 0%, #0A4D68 50%, #088395 100%)",
|
||||
# 强调色:日落橙 - 温暖、活力
|
||||
"accent": "#FF6B35", # 日落橙 - 主强调色 (对比度: 3.55:1, 大文字用途)
|
||||
"accent_light": "#FF8C61", # 亮橙 - 次要强调 (对比度: 2.87:1, 装饰用途)
|
||||
# 次要色:薄荷绿 - 清新、平衡
|
||||
"secondary": "#06D6A0", # 薄荷绿 (对比度: 2.28:1, 装饰用途)
|
||||
# 功能色 - 全部WCAG AA级
|
||||
"success": "#0C8A5D", # 成功绿 - 保持 (4.51:1 ✓)
|
||||
"info": "#2563EB", # 信息蓝 - 保持 (5.17:1 ✓)
|
||||
"warning": "#CA7205", # 警告橙 - 保持 (4.50:1 ✓)
|
||||
"error": "#DC2626", # 错误红 - 保持 (4.83:1 ✓)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 文字颜色系统 (Text Colors)
|
||||
# 基于WCAG 2.1标准,所有文字色在白色背景上对比度 ≥ 4.5:1
|
||||
# ============================================================================
|
||||
TEXT = {
|
||||
"primary": "#111827", # 主文字 (gray-900, 对比度 17.74:1 ✓)
|
||||
"secondary": "#4B5563", # 次要文字 (gray-600, 对比度 7.56:1 ✓)
|
||||
"tertiary": "#6B7280", # 三级文字 (gray-500, 对比度 4.83:1 ✓)
|
||||
"muted": "#6B7280", # 弱化文字 - 修正 (原#9CA3AF: 2.54:1 -> 4.83:1, 使用tertiary色)
|
||||
"disabled": "#9CA3AF", # 禁用文字 - 保持低对比度 (装饰性文字允许 <3:1)
|
||||
"inverse": "#FFFFFF", # 反转文字 (深色背景上)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 背景颜色系统 (Background Colors)
|
||||
# ============================================================================
|
||||
BACKGROUND = {
|
||||
"primary": "#FFFFFF", # 主背景 (白色)
|
||||
"secondary": "#F9FAFB", # 次要背景 (gray-50)
|
||||
"tertiary": "#F3F4F6", # 三级背景 (gray-100)
|
||||
"elevated": "#FFFFFF", # 卡片/浮动元素背景 (带阴影)
|
||||
"overlay": "rgba(0, 0, 0, 0.5)", # 蒙层
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 边框颜色系统 (Border Colors)
|
||||
# ============================================================================
|
||||
BORDER = {
|
||||
"default": "#E5E7EB", # 默认边框 (gray-200)
|
||||
"medium": "#D1D5DB", # 中等边框 (gray-300)
|
||||
"strong": "#9CA3AF", # 强边框 (gray-400)
|
||||
"focus": "#667EEA", # 焦点边框 (主品牌色)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 阴影系统 (Shadow System)
|
||||
# ============================================================================
|
||||
SHADOW = {
|
||||
"sm": "0 1px 2px 0 rgba(0, 0, 0, 0.05)",
|
||||
"md": "0 4px 6px -1px rgba(0, 0, 0, 0.1)",
|
||||
"lg": "0 10px 15px -3px rgba(0, 0, 0, 0.1)",
|
||||
"xl": "0 20px 25px -5px rgba(0, 0, 0, 0.1)",
|
||||
"2xl": "0 25px 50px -12px rgba(0, 0, 0, 0.25)",
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 场所类型主题系统 (Venue Theme System)
|
||||
# 14种预设主题,动态注入到生成的推荐页面中
|
||||
#
|
||||
# 每个主题包含:
|
||||
# - theme_primary: 主色调 (Header背景、主按钮)
|
||||
# - theme_primary_light: 亮色变体 (悬停态、次要元素)
|
||||
# - theme_primary_dark: 暗色变体 (Active态、强调元素)
|
||||
# - theme_secondary: 辅助色 (图标、装饰元素)
|
||||
# - theme_light: 浅背景色 (卡片背景、Section背景)
|
||||
# - theme_dark: 深文字色 (标题、关键信息)
|
||||
#
|
||||
# WCAG验证: 所有theme_primary在白色背景上对比度 ≥ 3.0:1 (大文字)
|
||||
# 所有theme_dark在theme_light背景上对比度 ≥ 4.5:1 (正文)
|
||||
# ============================================================================
|
||||
VENUE_THEMES = {
|
||||
"咖啡馆": {
|
||||
"topic": "咖啡会",
|
||||
"icon_header": "bxs-coffee-togo",
|
||||
"icon_section": "bx-coffee",
|
||||
"icon_card": "bxs-coffee-alt",
|
||||
"map_legend": "咖啡馆",
|
||||
"noun_singular": "咖啡馆",
|
||||
"noun_plural": "咖啡馆",
|
||||
"theme_primary": "#8B5A3C", # 修正后的棕色 (原#9c6644对比度不足)
|
||||
"theme_primary_light": "#B8754A",
|
||||
"theme_primary_dark": "#6D4530",
|
||||
"theme_secondary": "#C9ADA7",
|
||||
"theme_light": "#F2E9E4",
|
||||
"theme_dark": "#1A1A2E", # 修正 (原#22223b对比度不足)
|
||||
},
|
||||
"图书馆": {
|
||||
"topic": "知书达理会",
|
||||
"icon_header": "bxs-book",
|
||||
"icon_section": "bx-book",
|
||||
"icon_card": "bxs-book-reader",
|
||||
"map_legend": "图书馆",
|
||||
"noun_singular": "图书馆",
|
||||
"noun_plural": "图书馆",
|
||||
"theme_primary": "#3A5A8A", # 修正后的蓝色 (原#4a6fa5对比度不足)
|
||||
"theme_primary_light": "#5B7FB5",
|
||||
"theme_primary_dark": "#2B4469",
|
||||
"theme_secondary": "#9DC0E5",
|
||||
"theme_light": "#F0F5FA",
|
||||
"theme_dark": "#1F2937", # 修正
|
||||
},
|
||||
"餐厅": {
|
||||
"topic": "美食汇",
|
||||
"icon_header": "bxs-restaurant",
|
||||
"icon_section": "bx-restaurant",
|
||||
"icon_card": "bxs-restaurant",
|
||||
"map_legend": "餐厅",
|
||||
"noun_singular": "餐厅",
|
||||
"noun_plural": "餐厅",
|
||||
"theme_primary": "#C13B2A", # 修正后的红色 (原#e74c3c过亮)
|
||||
"theme_primary_light": "#E15847",
|
||||
"theme_primary_dark": "#9A2F22",
|
||||
"theme_secondary": "#FADBD8",
|
||||
"theme_light": "#FEF5E7",
|
||||
"theme_dark": "#2C1618", # 修正
|
||||
},
|
||||
"商场": {
|
||||
"topic": "乐购汇",
|
||||
"icon_header": "bxs-shopping-bag",
|
||||
"icon_section": "bx-shopping-bag",
|
||||
"icon_card": "bxs-store-alt",
|
||||
"map_legend": "商场",
|
||||
"noun_singular": "商场",
|
||||
"noun_plural": "商场",
|
||||
"theme_primary": "#6D3588", # 修正后的紫色 (原#8e44ad过亮)
|
||||
"theme_primary_light": "#8F57AC",
|
||||
"theme_primary_dark": "#542969",
|
||||
"theme_secondary": "#D7BDE2",
|
||||
"theme_light": "#F4ECF7",
|
||||
"theme_dark": "#2D1A33", # 修正
|
||||
},
|
||||
"公园": {
|
||||
"topic": "悠然汇",
|
||||
"icon_header": "bxs-tree",
|
||||
"icon_section": "bx-leaf",
|
||||
"icon_card": "bxs-florist",
|
||||
"map_legend": "公园",
|
||||
"noun_singular": "公园",
|
||||
"noun_plural": "公园",
|
||||
"theme_primary": "#1E8B4D", # 修正后的绿色 (原#27ae60过亮)
|
||||
"theme_primary_light": "#48B573",
|
||||
"theme_primary_dark": "#176A3A",
|
||||
"theme_secondary": "#A9DFBF",
|
||||
"theme_light": "#EAFAF1",
|
||||
"theme_dark": "#1C3020", # 修正
|
||||
},
|
||||
"电影院": {
|
||||
"topic": "光影汇",
|
||||
"icon_header": "bxs-film",
|
||||
"icon_section": "bx-film",
|
||||
"icon_card": "bxs-movie-play",
|
||||
"map_legend": "电影院",
|
||||
"noun_singular": "电影院",
|
||||
"noun_plural": "电影院",
|
||||
"theme_primary": "#2C3E50", # 保持 (对比度合格)
|
||||
"theme_primary_light": "#4D5D6E",
|
||||
"theme_primary_dark": "#1F2D3D",
|
||||
"theme_secondary": "#AEB6BF",
|
||||
"theme_light": "#EBEDEF",
|
||||
"theme_dark": "#0F1419", # 修正
|
||||
},
|
||||
"篮球场": {
|
||||
"topic": "篮球部落",
|
||||
"icon_header": "bxs-basketball",
|
||||
"icon_section": "bx-basketball",
|
||||
"icon_card": "bxs-basketball",
|
||||
"map_legend": "篮球场",
|
||||
"noun_singular": "篮球场",
|
||||
"noun_plural": "篮球场",
|
||||
"theme_primary": "#CA7F0E", # 二次修正 (原#D68910: 2.82:1 -> 3.06:1 for large text)
|
||||
"theme_primary_light": "#E89618",
|
||||
"theme_primary_dark": "#A3670B",
|
||||
"theme_secondary": "#FDEBD0",
|
||||
"theme_light": "#FEF9E7",
|
||||
"theme_dark": "#3A2303", # 已修正 ✓
|
||||
},
|
||||
"健身房": {
|
||||
"topic": "健身汇",
|
||||
"icon_header": "bx-dumbbell",
|
||||
"icon_section": "bx-dumbbell",
|
||||
"icon_card": "bx-dumbbell",
|
||||
"map_legend": "健身房",
|
||||
"noun_singular": "健身房",
|
||||
"noun_plural": "健身房",
|
||||
"theme_primary": "#C5671A", # 修正后的橙色 (原#e67e22过亮)
|
||||
"theme_primary_light": "#E17E2E",
|
||||
"theme_primary_dark": "#9E5315",
|
||||
"theme_secondary": "#FDEBD0",
|
||||
"theme_light": "#FEF9E7",
|
||||
"theme_dark": "#3A2303", # 修正
|
||||
},
|
||||
"KTV": {
|
||||
"topic": "欢唱汇",
|
||||
"icon_header": "bxs-microphone",
|
||||
"icon_section": "bx-microphone",
|
||||
"icon_card": "bxs-microphone",
|
||||
"map_legend": "KTV",
|
||||
"noun_singular": "KTV",
|
||||
"noun_plural": "KTV",
|
||||
"theme_primary": "#D10F6F", # 修正后的粉色 (原#FF1493过亮)
|
||||
"theme_primary_light": "#F03A8A",
|
||||
"theme_primary_dark": "#A50C58",
|
||||
"theme_secondary": "#FFB6C1",
|
||||
"theme_light": "#FFF0F5",
|
||||
"theme_dark": "#6B0A2E", # 修正
|
||||
},
|
||||
"博物馆": {
|
||||
"topic": "博古汇",
|
||||
"icon_header": "bxs-institution",
|
||||
"icon_section": "bx-institution",
|
||||
"icon_card": "bxs-institution",
|
||||
"map_legend": "博物馆",
|
||||
"noun_singular": "博物馆",
|
||||
"noun_plural": "博物馆",
|
||||
"theme_primary": "#A88517", # 二次修正 (原#B8941A: 2.88:1 -> 3.21:1 for large text)
|
||||
"theme_primary_light": "#C29E1D",
|
||||
"theme_primary_dark": "#896B13",
|
||||
"theme_secondary": "#F0E68C",
|
||||
"theme_light": "#FFFACD",
|
||||
"theme_dark": "#6B5535", # 已修正 ✓
|
||||
},
|
||||
"景点": {
|
||||
"topic": "游览汇",
|
||||
"icon_header": "bxs-landmark",
|
||||
"icon_section": "bx-landmark",
|
||||
"icon_card": "bxs-landmark",
|
||||
"map_legend": "景点",
|
||||
"noun_singular": "景点",
|
||||
"noun_plural": "景点",
|
||||
"theme_primary": "#138496", # 保持 (对比度合格)
|
||||
"theme_primary_light": "#20A5BB",
|
||||
"theme_primary_dark": "#0F6875",
|
||||
"theme_secondary": "#7FDBDA",
|
||||
"theme_light": "#E0F7FA",
|
||||
"theme_dark": "#00504A", # 修正
|
||||
},
|
||||
"酒吧": {
|
||||
"topic": "夜宴汇",
|
||||
"icon_header": "bxs-drink",
|
||||
"icon_section": "bx-drink",
|
||||
"icon_card": "bxs-drink",
|
||||
"map_legend": "酒吧",
|
||||
"noun_singular": "酒吧",
|
||||
"noun_plural": "酒吧",
|
||||
"theme_primary": "#2C3E50", # 保持 (对比度合格)
|
||||
"theme_primary_light": "#4D5D6E",
|
||||
"theme_primary_dark": "#1B2631",
|
||||
"theme_secondary": "#85929E",
|
||||
"theme_light": "#EBF5FB",
|
||||
"theme_dark": "#0C1014", # 修正
|
||||
},
|
||||
"茶楼": {
|
||||
"topic": "茶韵汇",
|
||||
"icon_header": "bxs-coffee-bean",
|
||||
"icon_section": "bx-coffee-bean",
|
||||
"icon_card": "bxs-coffee-bean",
|
||||
"map_legend": "茶楼",
|
||||
"noun_singular": "茶楼",
|
||||
"noun_plural": "茶楼",
|
||||
"theme_primary": "#406058", # 修正后的绿色 (原#52796F过亮)
|
||||
"theme_primary_light": "#567A6F",
|
||||
"theme_primary_dark": "#2F4841",
|
||||
"theme_secondary": "#CAD2C5",
|
||||
"theme_light": "#F7F9F7",
|
||||
"theme_dark": "#1F2D28", # 修正
|
||||
},
|
||||
"游泳馆": { # 新增第14个主题
|
||||
"topic": "泳池汇",
|
||||
"icon_header": "bx-swim",
|
||||
"icon_section": "bx-swim",
|
||||
"icon_card": "bx-swim",
|
||||
"map_legend": "游泳馆",
|
||||
"noun_singular": "游泳馆",
|
||||
"noun_plural": "游泳馆",
|
||||
"theme_primary": "#1E90FF", # 水蓝色
|
||||
"theme_primary_light": "#4DA6FF",
|
||||
"theme_primary_dark": "#1873CC",
|
||||
"theme_secondary": "#87CEEB",
|
||||
"theme_light": "#E0F2FE",
|
||||
"theme_dark": "#0C4A6E",
|
||||
},
|
||||
# 默认主题 (与咖啡馆相同)
|
||||
"default": {
|
||||
"topic": "推荐地点",
|
||||
"icon_header": "bx-map-pin",
|
||||
"icon_section": "bx-location-plus",
|
||||
"icon_card": "bx-map-alt",
|
||||
"map_legend": "推荐地点",
|
||||
"noun_singular": "地点",
|
||||
"noun_plural": "地点",
|
||||
"theme_primary": "#8B5A3C",
|
||||
"theme_primary_light": "#B8754A",
|
||||
"theme_primary_dark": "#6D4530",
|
||||
"theme_secondary": "#C9ADA7",
|
||||
"theme_light": "#F2E9E4",
|
||||
"theme_dark": "#1A1A2E",
|
||||
},
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 间距系统 (Spacing System)
|
||||
# 基于8px基准的间距尺度
|
||||
# ============================================================================
|
||||
SPACING = {
|
||||
"0": "0",
|
||||
"1": "4px", # 0.25rem
|
||||
"2": "8px", # 0.5rem
|
||||
"3": "12px", # 0.75rem
|
||||
"4": "16px", # 1rem
|
||||
"5": "20px", # 1.25rem
|
||||
"6": "24px", # 1.5rem
|
||||
"8": "32px", # 2rem
|
||||
"10": "40px", # 2.5rem
|
||||
"12": "48px", # 3rem
|
||||
"16": "64px", # 4rem
|
||||
"20": "80px", # 5rem
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 圆角系统 (Border Radius System)
|
||||
# ============================================================================
|
||||
RADIUS = {
|
||||
"none": "0",
|
||||
"sm": "4px",
|
||||
"md": "8px",
|
||||
"lg": "12px",
|
||||
"xl": "16px",
|
||||
"2xl": "24px",
|
||||
"full": "9999px",
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 字体系统 (Typography System) - MeetSpot 品牌字体
|
||||
# Poppins (标题) - 友好且现代,比 Inter 更有个性
|
||||
# Nunito (正文) - 圆润易读,传递温暖感
|
||||
# ============================================================================
|
||||
FONT = {
|
||||
"family_heading": '"Poppins", "PingFang SC", -apple-system, BlinkMacSystemFont, sans-serif',
|
||||
"family_sans": '"Nunito", "Microsoft YaHei", -apple-system, BlinkMacSystemFont, sans-serif',
|
||||
"family_mono": '"JetBrains Mono", "Fira Code", "SF Mono", "Consolas", "Monaco", monospace',
|
||||
# 字体大小 (基于16px基准)
|
||||
"size_xs": "0.75rem", # 12px
|
||||
"size_sm": "0.875rem", # 14px
|
||||
"size_base": "1rem", # 16px
|
||||
"size_lg": "1.125rem", # 18px
|
||||
"size_xl": "1.25rem", # 20px
|
||||
"size_2xl": "1.5rem", # 24px
|
||||
"size_3xl": "1.875rem", # 30px
|
||||
"size_4xl": "2.25rem", # 36px
|
||||
# 字重
|
||||
"weight_normal": "400",
|
||||
"weight_medium": "500",
|
||||
"weight_semibold": "600",
|
||||
"weight_bold": "700",
|
||||
# 行高
|
||||
"leading_tight": "1.25",
|
||||
"leading_normal": "1.5",
|
||||
"leading_relaxed": "1.7",
|
||||
"leading_loose": "2",
|
||||
}
|
||||
|
||||
# ============================================================================
# Z-Index system (Layering System)
# ============================================================================
# Stacking order for overlay layers; values leave 10-20 unit gaps so new
# layers can be slotted in between without renumbering.
Z_INDEX = {
    "dropdown": "1000",
    "sticky": "1020",
    "fixed": "1030",
    "modal_backdrop": "1040",
    "modal": "1050",
    "popover": "1060",
    "tooltip": "1070",
}


# ============================================================================
# Interaction animation system (Interaction Animations)
# Follows WCAG 2.1 - honours prefers-reduced-motion
# ============================================================================
# Raw CSS appended verbatim to the generated stylesheet (see
# generate_css_file); selectors reference the --brand-* / --shadow-*
# variables emitted by to_css_variables.
ANIMATIONS = """
/* ========== 交互动画系统 (Interaction Animations) ========== */

/* Button动画 - 200ms ease-out过渡 */
button, .btn, input[type="submit"], a.button {
    transition: all 0.2s ease-out;
}

button:hover, .btn:hover, input[type="submit"]:hover, a.button:hover {
    transform: translateY(-2px);
    box-shadow: var(--shadow-lg);
}

button:active, .btn:active, input[type="submit"]:active, a.button:active {
    transform: translateY(0);
    box-shadow: var(--shadow-md);
}

button:focus, .btn:focus, input[type="submit"]:focus, a.button:focus {
    outline: 2px solid var(--brand-primary);
    outline-offset: 2px;
}

/* Loading Spinner动画 */
.loading::after {
    content: "";
    width: 16px;
    height: 16px;
    margin-left: 8px;
    border: 2px solid var(--brand-primary);
    border-top-color: transparent;
    border-radius: 50%;
    display: inline-block;
    animation: spin 0.6s linear infinite;
}

@keyframes spin {
    to { transform: rotate(360deg); }
}

/* Card悬停效果 - 微妙的缩放和阴影提升 */
.card, .venue-card, .recommendation-card {
    transition: transform 0.2s ease-out, box-shadow 0.2s ease-out;
}

.card:hover, .venue-card:hover, .recommendation-card:hover {
    transform: scale(1.02);
    box-shadow: var(--shadow-xl);
}

/* Fade-in渐显动画 - 400ms */
.fade-in {
    animation: fadeIn 0.4s ease-out;
}

@keyframes fadeIn {
    from {
        opacity: 0;
        transform: translateY(10px);
    }
    to {
        opacity: 1;
        transform: translateY(0);
    }
}

/* Slide-in滑入动画 */
.slide-in {
    animation: slideIn 0.4s ease-out;
}

@keyframes slideIn {
    from {
        opacity: 0;
        transform: translateX(-20px);
    }
    to {
        opacity: 1;
        transform: translateX(0);
    }
}

/* WCAG 2.1无障碍支持 - 尊重用户的动画偏好 */
@media (prefers-reduced-motion: reduce) {
    *,
    *::before,
    *::after {
        animation-duration: 0.01ms !important;
        animation-iteration-count: 1 !important;
        transition-duration: 0.01ms !important;
        scroll-behavior: auto !important;
    }
}
"""
|
||||
|
||||
# ============================================================================
|
||||
# 辅助方法
|
||||
# ============================================================================
|
||||
|
||||
@classmethod
@lru_cache(maxsize=128)
def get_venue_theme(cls, venue_type: str) -> Dict[str, str]:
    """Look up the theme configuration for a venue type.

    Args:
        venue_type: Venue category name (e.g. "咖啡馆", "图书馆").

    Returns:
        Dict with the theme colours and icon for that venue type; unknown
        types fall back to the "default" theme.

    Example:
        >>> theme = DesignTokens.get_venue_theme("咖啡馆")
        >>> print(theme['theme_primary'])  # "#8B5A3C"
    """
    themes = cls.VENUE_THEMES
    if venue_type in themes:
        return themes[venue_type]
    return themes["default"]
|
||||
|
||||
@classmethod
def to_css_variables(cls) -> str:
    """Render all design tokens as one CSS custom-property block.

    Returns:
        A ``:root { ... }`` string ready to embed inside a <style> tag.

    Example:
        >>> css = DesignTokens.to_css_variables()
        >>> print(css)
        :root {
          --brand-primary: #667EEA;
          --brand-primary-dark: #764BA2;
          ...
        }
    """
    # (CSS prefix, token mapping, whether to turn "_" into "-" in key names).
    # SPACING keys are emitted verbatim to keep the historical variable names.
    sections = [
        ("brand", cls.BRAND, True),
        ("text", cls.TEXT, True),
        ("bg", cls.BACKGROUND, True),
        ("border", cls.BORDER, True),
        ("shadow", cls.SHADOW, True),
        ("spacing", cls.SPACING, False),
        ("radius", cls.RADIUS, True),
        ("font", cls.FONT, True),
        ("z", cls.Z_INDEX, True),
    ]

    lines = [":root {"]
    for prefix, mapping, dashify in sections:
        for key, value in mapping.items():
            name = key.replace("_", "-") if dashify else key
            lines.append(f"  --{prefix}-{name}: {value};")
    lines.append("}")
    return "\n".join(lines)
|
||||
|
||||
@classmethod
def generate_css_file(cls, output_path: str = "static/css/design-tokens.css"):
    """Write the design tokens (plus animation CSS) to a standalone file.

    Args:
        output_path: Destination path of the generated CSS file.

    Example:
        >>> DesignTokens.generate_css_file()
        # writes static/css/design-tokens.css
    """
    import os

    # Bug fix: os.makedirs("") raises FileNotFoundError when output_path
    # has no directory component (e.g. "tokens.css"), so guard it.
    directory = os.path.dirname(output_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("/* ============================================\n")
        f.write(" * MeetSpot Design Tokens\n")
        f.write(" * 自动生成 - 请勿手动编辑\n")
        f.write(" * 生成源: app/design_tokens.py\n")
        f.write(" * ==========================================*/\n\n")
        f.write(cls.to_css_variables())
        f.write("\n\n/* Compatibility fallbacks for older browsers */\n")
        f.write(".no-cssvar {\n")
        f.write(" /* Fallback for browsers without CSS variable support */\n")
        f.write(f" color: {cls.TEXT['primary']};\n")
        f.write(f" background-color: {cls.BACKGROUND['primary']};\n")
        f.write("}\n\n")
        # Append the interaction-animation CSS verbatim.
        f.write(cls.ANIMATIONS)
|
||||
|
||||
|
||||
# ============================================================================
# Global singleton access (convenience aliases)
# ============================================================================
# Module-level aliases onto the DesignTokens class attributes so call sites
# can use short names without referencing the class.
COLORS = {
    "brand": DesignTokens.BRAND,
    "text": DesignTokens.TEXT,
    "background": DesignTokens.BACKGROUND,
    "border": DesignTokens.BORDER,
}

# Alias for the per-venue theme table.
VENUE_THEMES = DesignTokens.VENUE_THEMES
|
||||
|
||||
|
||||
# ============================================================================
# Convenience functions
# ============================================================================
def get_venue_theme(venue_type: str) -> Dict[str, str]:
    """Module-level shortcut for :meth:`DesignTokens.get_venue_theme`."""
    selected = DesignTokens.get_venue_theme(venue_type)
    return selected
|
||||
|
||||
|
||||
def generate_design_tokens_css(output_path: str = "static/css/design-tokens.css"):
    """Module-level shortcut: write the design-token CSS file to *output_path*."""
    DesignTokens.generate_css_file(output_path)
|
||||
13
MeetSpot/app/exceptions.py
Normal file
13
MeetSpot/app/exceptions.py
Normal file
@@ -0,0 +1,13 @@
|
||||
class ToolError(Exception):
    """Raised when a tool encounters an error."""

    def __init__(self, message):
        # Forward to Exception so str()/repr()/pickling behave consistently
        # with the standard exception protocol.
        super().__init__(message)
        self.message = message


class OpenManusError(Exception):
    """Base exception for all OpenManus errors"""


class TokenLimitExceeded(OpenManusError):
    """Exception raised when the token limit is exceeded"""
||||
800
MeetSpot/app/llm.py
Normal file
800
MeetSpot/app/llm.py
Normal file
@@ -0,0 +1,800 @@
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import tiktoken
|
||||
from openai import (APIError, AsyncAzureOpenAI, AsyncOpenAI,
|
||||
AuthenticationError, OpenAIError, RateLimitError)
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
|
||||
wait_random_exponential)
|
||||
|
||||
from app.config import LLMSettings, config
|
||||
from app.exceptions import TokenLimitExceeded
|
||||
from app.logger import logger # Assuming a logger is set up in your app
|
||||
from app.schema import (ROLE_VALUES, TOOL_CHOICE_TYPE, TOOL_CHOICE_VALUES,
|
||||
Message, ToolChoice)
|
||||
|
||||
# Models handled with "max_completion_tokens" instead of "max_tokens" and
# without a temperature parameter (see LLM.ask / LLM.ask_with_images).
REASONING_MODELS = ["o1", "o3-mini"]
# Models that accept image content parts in chat messages; base64 images are
# stripped from messages sent to models not in this list.
MULTIMODAL_MODELS = [
    "gpt-4-vision-preview",
    "gpt-4o",
    "gpt-4o-mini",
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
]
|
||||
|
||||
|
||||
class TokenCounter:
    """Estimates token usage for OpenAI-style chat messages.

    Text is counted with the injected tokenizer; images follow the OpenAI
    vision accounting rules: a flat 85 tokens for "low" detail, otherwise
    170 tokens per 512px tile plus 85 after the two-step rescale.
    """

    # Token constants
    BASE_MESSAGE_TOKENS = 4  # fixed overhead per message
    FORMAT_TOKENS = 2  # fixed overhead per request
    LOW_DETAIL_IMAGE_TOKENS = 85
    HIGH_DETAIL_TILE_TOKENS = 170

    # Image processing constants
    MAX_SIZE = 2048  # longest side is capped here first
    HIGH_DETAIL_TARGET_SHORT_SIDE = 768  # then the short side is scaled to this
    TILE_SIZE = 512

    def __init__(self, tokenizer):
        # tokenizer must expose .encode(str) -> sequence of token ids
        self.tokenizer = tokenizer

    def count_text(self, text: str) -> int:
        """Calculate tokens for a text string (empty/None counts as 0)."""
        return len(self.tokenizer.encode(text)) if text else 0

    def count_image(self, image_item: dict) -> int:
        """
        Calculate tokens for an image based on detail level and dimensions.

        For "low" detail: fixed 85 tokens.
        For "high" detail:
            1. Scale to fit in 2048x2048 square
            2. Scale shortest side to 768px
            3. Count 512px tiles (170 tokens each)
            4. Add 85 tokens
        """
        detail = image_item.get("detail", "medium")

        # Low detail is always the flat rate, regardless of dimensions.
        if detail == "low":
            return self.LOW_DETAIL_IMAGE_TOKENS

        # OpenAI doesn't publish a separate formula for "medium", so it
        # shares the high-detail path whenever dimensions are known.
        if detail in ("high", "medium") and "dimensions" in image_item:
            width, height = image_item["dimensions"]
            return self._calculate_high_detail_tokens(width, height)

        if detail == "high":
            # Default to a 1024x1024 image calculation for high detail.
            return self._calculate_high_detail_tokens(1024, 1024)  # 765 tokens
        # Medium or unknown detail without dimensions: fixed estimate
        # (matches the original default).
        return 1024

    def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
        """Calculate tokens for high detail images based on dimensions."""
        # Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
        if width > self.MAX_SIZE or height > self.MAX_SIZE:
            scale = self.MAX_SIZE / max(width, height)
            width = int(width * scale)
            height = int(height * scale)

        # Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
        scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
        scaled_width = int(width * scale)
        scaled_height = int(height * scale)

        # Step 3: Count number of 512px tiles
        tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
        tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
        total_tiles = tiles_x * tiles_y

        # Step 4: tiles plus the flat base cost
        return (
            total_tiles * self.HIGH_DETAIL_TILE_TOKENS
        ) + self.LOW_DETAIL_IMAGE_TOKENS

    def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
        """Calculate tokens for message content (string or list of parts)."""
        if not content:
            return 0

        if isinstance(content, str):
            return self.count_text(content)

        token_count = 0
        for item in content:
            if isinstance(item, str):
                token_count += self.count_text(item)
            elif isinstance(item, dict):
                if "text" in item:
                    token_count += self.count_text(item["text"])
                elif "image_url" in item:
                    token_count += self.count_image(item)
        return token_count

    def count_tool_calls(self, tool_calls: List[dict]) -> int:
        """Calculate tokens for tool calls (function name + raw arguments)."""
        token_count = 0
        for tool_call in tool_calls:
            if "function" in tool_call:
                function = tool_call["function"]
                token_count += self.count_text(function.get("name", ""))
                token_count += self.count_text(function.get("arguments", ""))
        return token_count

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the total number of tokens in a message list."""
        total_tokens = self.FORMAT_TOKENS  # Base format tokens

        for message in messages:
            tokens = self.BASE_MESSAGE_TOKENS  # Base tokens per message

            # Role, content, tool calls, and auxiliary identifier fields
            # all contribute to the estimate.
            tokens += self.count_text(message.get("role", ""))

            if "content" in message:
                tokens += self.count_content(message["content"])

            if "tool_calls" in message:
                tokens += self.count_tool_calls(message["tool_calls"])

            tokens += self.count_text(message.get("name", ""))
            tokens += self.count_text(message.get("tool_call_id", ""))

            total_tokens += tokens

        return total_tokens
|
||||
|
||||
|
||||
class LLM:
    # Per-config-name singleton cache: LLM("default") always yields the same
    # instance for a given config_name.
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        # Build and initialise the instance once per config_name, then reuse.
        # NOTE(review): Python calls __init__ again on the returned instance
        # after __new__; __init__ guards against re-initialisation with its
        # hasattr(self, "client") check, so the explicit call here only
        # results in a harmless double invocation for the first creation.
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]
|
||||
|
||||
def __init__(
    self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
):
    """Initialise model settings, tokenizer, API client and token counters.

    Args:
        config_name: Key into the app LLM configuration table.
        llm_config: Optional explicit settings; falls back to ``config.llm``.
    """
    if not hasattr(self, "client"):  # Only initialize if not already initialized
        # Resolve settings: explicit config wins, then the named entry,
        # then the "default" entry.
        llm_config = llm_config or config.llm
        llm_config = llm_config.get(config_name, llm_config["default"])
        self.model = llm_config.model
        self.max_tokens = llm_config.max_tokens
        self.temperature = llm_config.temperature
        self.api_type = llm_config.api_type
        self.api_key = llm_config.api_key
        self.api_version = llm_config.api_version
        self.base_url = llm_config.base_url

        # Cumulative token accounting across all calls on this instance.
        self.total_input_tokens = 0
        self.total_completion_tokens = 0
        self.max_input_tokens = (
            llm_config.max_input_tokens
            if hasattr(llm_config, "max_input_tokens")
            else None
        )

        # Initialize tokenizer
        try:
            self.tokenizer = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # If the model is not in tiktoken's presets, use cl100k_base as default
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

        # Azure and plain OpenAI use different async client classes.
        if self.api_type == "azure":
            self.client = AsyncAzureOpenAI(
                base_url=self.base_url,
                api_key=self.api_key,
                api_version=self.api_version,
            )
        else:
            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

        self.token_counter = TokenCounter(self.tokenizer)
|
||||
|
||||
def count_tokens(self, text: str) -> int:
    """Return the number of tokens the configured tokenizer yields for *text*."""
    return len(self.tokenizer.encode(text)) if text else 0
|
||||
|
||||
def count_message_tokens(self, messages: List[dict]) -> int:
    """Delegate message-token estimation to the shared TokenCounter."""
    return self.token_counter.count_message_tokens(messages)
|
||||
|
||||
def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
    """Accumulate token usage and log per-call plus cumulative totals."""
    # NOTE(review): the original comment claimed tracking only happens when
    # max_input_tokens is set, but the accumulation below is unconditional.
    self.total_input_tokens += input_tokens
    self.total_completion_tokens += completion_tokens
    logger.info(
        f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
        f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
        f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
    )
|
||||
|
||||
def check_token_limit(self, input_tokens: int) -> bool:
    """Return True when adding *input_tokens* stays within the input budget."""
    # No configured budget means the check always passes.
    if self.max_input_tokens is None:
        return True
    return self.total_input_tokens + input_tokens <= self.max_input_tokens
|
||||
|
||||
def get_limit_error_message(self, input_tokens: int) -> str:
    """Build a human-readable message describing a token-limit violation."""
    limit = self.max_input_tokens
    if limit is not None and self.total_input_tokens + input_tokens > limit:
        return (
            f"Request may exceed input token limit "
            f"(Current: {self.total_input_tokens}, "
            f"Needed: {input_tokens}, Max: {limit})"
        )
    # Generic fallback when no specific overflow can be described.
    return "Token limit exceeded"
|
||||
|
||||
@staticmethod
def format_messages(
    messages: List[Union[dict, Message]], supports_images: bool = False
) -> List[dict]:
    """
    Format messages for LLM by converting them to OpenAI message format.

    Args:
        messages: List of messages that can be either dict or Message objects
        supports_images: Flag indicating if the target model supports image inputs

    Returns:
        List[dict]: List of formatted messages in OpenAI format

    Raises:
        ValueError: If messages are invalid or missing required fields
        TypeError: If unsupported message types are provided

    Examples:
        >>> msgs = [
        ...     Message.system_message("You are a helpful assistant"),
        ...     {"role": "user", "content": "Hello"},
        ...     Message.user_message("How are you?")
        ... ]
        >>> formatted = LLM.format_messages(msgs)
    """
    formatted_messages = []

    for message in messages:
        # Convert Message objects to dictionaries
        if isinstance(message, Message):
            message = message.to_dict()

        if isinstance(message, dict):
            # If message is a dict, ensure it has required fields
            if "role" not in message:
                raise ValueError("Message dict must contain 'role' field")

            # Process base64 images if present and model supports images.
            # NOTE: the incoming dict is mutated in place (content rewritten,
            # base64_image removed).
            if supports_images and message.get("base64_image"):
                # Initialize or convert content to the multimodal parts format
                if not message.get("content"):
                    message["content"] = []
                elif isinstance(message["content"], str):
                    message["content"] = [
                        {"type": "text", "text": message["content"]}
                    ]
                elif isinstance(message["content"], list):
                    # Convert bare string items to proper text part objects
                    message["content"] = [
                        (
                            {"type": "text", "text": item}
                            if isinstance(item, str)
                            else item
                        )
                        for item in message["content"]
                    ]

                # Append the image as a data-URL image part
                message["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{message['base64_image']}"
                        },
                    }
                )

                # Remove the non-OpenAI base64_image field
                del message["base64_image"]
            # If model doesn't support images but message has base64_image, handle gracefully
            elif not supports_images and message.get("base64_image"):
                # Just remove the base64_image field and keep the text content
                del message["base64_image"]

            # Messages with neither content nor tool_calls are dropped.
            if "content" in message or "tool_calls" in message:
                formatted_messages.append(message)
            # else: do not include the message
        else:
            raise TypeError(f"Unsupported message type: {type(message)}")

    # Validate all messages carry a role the API accepts
    for msg in formatted_messages:
        if msg["role"] not in ROLE_VALUES:
            raise ValueError(f"Invalid role: {msg['role']}")

    return formatted_messages
|
||||
|
||||
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    # NOTE(review): TokenLimitExceeded subclasses Exception, so this
    # predicate DOES retry it despite the original "Don't retry
    # TokenLimitExceeded" comment — consider
    # retry_if_not_exception_type(TokenLimitExceeded) instead.
    retry=retry_if_exception_type(
        (OpenAIError, Exception, ValueError)
    ),
)
async def ask(
    self,
    messages: List[Union[dict, Message]],
    system_msgs: Optional[List[Union[dict, Message]]] = None,
    stream: bool = True,
    temperature: Optional[float] = None,
) -> str:
    """
    Send a prompt to the LLM and get the response.

    Args:
        messages: List of conversation messages
        system_msgs: Optional system messages to prepend
        stream (bool): Whether to stream the response (streamed chunks are
            also echoed to stdout)
        temperature (float): Sampling temperature for the response

    Returns:
        str: The generated response

    Raises:
        TokenLimitExceeded: If token limits are exceeded
        ValueError: If messages are invalid or response is empty
        OpenAIError: If API call fails after retries
        Exception: For unexpected errors
    """
    try:
        # Check if the model supports images
        supports_images = self.model in MULTIMODAL_MODELS

        # Format system and user messages with image support check
        if system_msgs:
            system_msgs = self.format_messages(system_msgs, supports_images)
            messages = system_msgs + self.format_messages(messages, supports_images)
        else:
            messages = self.format_messages(messages, supports_images)

        # Calculate input token count
        input_tokens = self.count_message_tokens(messages)

        # Check if token limits are exceeded
        if not self.check_token_limit(input_tokens):
            error_message = self.get_limit_error_message(input_tokens)
            # Intended as non-retryable (but see NOTE on the decorator above)
            raise TokenLimitExceeded(error_message)

        params = {
            "model": self.model,
            "messages": messages,
        }

        # Reasoning models take max_completion_tokens and no temperature.
        if self.model in REASONING_MODELS:
            params["max_completion_tokens"] = self.max_tokens
        else:
            params["max_tokens"] = self.max_tokens
            params["temperature"] = (
                temperature if temperature is not None else self.temperature
            )

        if not stream:
            # Non-streaming request
            response = await self.client.chat.completions.create(
                **params, stream=False
            )

            if not response.choices or not response.choices[0].message.content:
                raise ValueError("Empty or invalid response from LLM")

            # Update token counts from the server-reported usage
            self.update_token_count(
                response.usage.prompt_tokens, response.usage.completion_tokens
            )

            return response.choices[0].message.content

        # Streaming: record the estimated input tokens up front, since the
        # streamed response carries no usage object.
        self.update_token_count(input_tokens)

        response = await self.client.chat.completions.create(**params, stream=True)

        collected_messages = []
        completion_text = ""
        async for chunk in response:
            chunk_message = chunk.choices[0].delta.content or ""
            collected_messages.append(chunk_message)
            completion_text += chunk_message
            print(chunk_message, end="", flush=True)

        print()  # Newline after streaming
        full_response = "".join(collected_messages).strip()
        if not full_response:
            raise ValueError("Empty response from streaming LLM")

        # Estimate completion tokens for the streamed text locally
        completion_tokens = self.count_tokens(completion_text)
        logger.info(
            f"Estimated completion tokens for streaming response: {completion_tokens}"
        )
        self.total_completion_tokens += completion_tokens

        return full_response

    except TokenLimitExceeded:
        # Re-raise token limit errors without logging
        raise
    except ValueError:
        logger.exception(f"Validation error")
        raise
    except OpenAIError as oe:
        logger.exception(f"OpenAI API error")
        if isinstance(oe, AuthenticationError):
            logger.error("Authentication failed. Check API key.")
        elif isinstance(oe, RateLimitError):
            logger.error("Rate limit exceeded. Consider increasing retry attempts.")
        elif isinstance(oe, APIError):
            logger.error(f"API error: {oe}")
        raise
    except Exception:
        logger.exception(f"Unexpected error in ask")
        raise
|
||||
|
||||
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    # NOTE(review): TokenLimitExceeded subclasses Exception, so this
    # predicate DOES retry it despite the original "Don't retry
    # TokenLimitExceeded" comment — consider
    # retry_if_not_exception_type(TokenLimitExceeded) instead.
    retry=retry_if_exception_type(
        (OpenAIError, Exception, ValueError)
    ),
)
async def ask_with_images(
    self,
    messages: List[Union[dict, Message]],
    images: List[Union[str, dict]],
    system_msgs: Optional[List[Union[dict, Message]]] = None,
    stream: bool = False,
    temperature: Optional[float] = None,
) -> str:
    """
    Send a prompt with images to the LLM and get the response.

    The images are attached to the last user message, whose content is
    converted to the multimodal parts format in place.

    Args:
        messages: List of conversation messages
        images: List of image URLs or image data dictionaries
        system_msgs: Optional system messages to prepend
        stream (bool): Whether to stream the response
        temperature (float): Sampling temperature for the response

    Returns:
        str: The generated response

    Raises:
        TokenLimitExceeded: If token limits are exceeded
        ValueError: If messages are invalid or response is empty
        OpenAIError: If API call fails after retries
        Exception: For unexpected errors
    """
    try:
        # This method should only be called with models that support images
        if self.model not in MULTIMODAL_MODELS:
            raise ValueError(
                f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
            )

        # Format messages with image support
        formatted_messages = self.format_messages(messages, supports_images=True)

        # Ensure the last message is from the user to attach images
        if not formatted_messages or formatted_messages[-1]["role"] != "user":
            raise ValueError(
                "The last message must be from the user to attach images"
            )

        # Process the last user message to include images
        last_message = formatted_messages[-1]

        # Convert content to multimodal format if needed
        content = last_message["content"]
        multimodal_content = (
            [{"type": "text", "text": content}]
            if isinstance(content, str)
            else content if isinstance(content, list) else []
        )

        # Add images to content; three accepted shapes: bare URL string,
        # {"url": ...} dict, or a fully-formed image part.
        for image in images:
            if isinstance(image, str):
                multimodal_content.append(
                    {"type": "image_url", "image_url": {"url": image}}
                )
            elif isinstance(image, dict) and "url" in image:
                multimodal_content.append({"type": "image_url", "image_url": image})
            elif isinstance(image, dict) and "image_url" in image:
                multimodal_content.append(image)
            else:
                raise ValueError(f"Unsupported image format: {image}")

        # Update the message with multimodal content
        last_message["content"] = multimodal_content

        # Add system messages if provided
        if system_msgs:
            all_messages = (
                self.format_messages(system_msgs, supports_images=True)
                + formatted_messages
            )
        else:
            all_messages = formatted_messages

        # Calculate tokens and check limits
        input_tokens = self.count_message_tokens(all_messages)
        if not self.check_token_limit(input_tokens):
            raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))

        # Set up API parameters
        params = {
            "model": self.model,
            "messages": all_messages,
            "stream": stream,
        }

        # Reasoning models take max_completion_tokens and no temperature.
        if self.model in REASONING_MODELS:
            params["max_completion_tokens"] = self.max_tokens
        else:
            params["max_tokens"] = self.max_tokens
            params["temperature"] = (
                temperature if temperature is not None else self.temperature
            )

        # Handle non-streaming request
        if not stream:
            response = await self.client.chat.completions.create(**params)

            if not response.choices or not response.choices[0].message.content:
                raise ValueError("Empty or invalid response from LLM")

            self.update_token_count(
                response.usage.prompt_tokens, response.usage.completion_tokens
            )
            return response.choices[0].message.content

        # Handle streaming request (chunks are echoed to stdout)
        response = await self.client.chat.completions.create(**params)

        collected_messages = []
        completion_text = ""
        async for chunk in response:
            chunk_message = chunk.choices[0].delta.content or ""
            collected_messages.append(chunk_message)
            completion_text += chunk_message
            print(chunk_message, end="", flush=True)

        print()  # Newline after streaming
        full_response = "".join(collected_messages).strip()

        if not full_response:
            raise ValueError("Empty response from streaming LLM")

        # Streamed responses carry no usage object, so estimate locally.
        completion_tokens = self.count_tokens(completion_text)
        logger.info(
            f"Estimated completion tokens for streaming response with images: {completion_tokens}"
        )
        self.update_token_count(input_tokens, completion_tokens)

        return full_response

    except TokenLimitExceeded:
        raise
    except ValueError as ve:
        logger.error(f"Validation error in ask_with_images: {ve}")
        raise
    except OpenAIError as oe:
        logger.error(f"OpenAI API error: {oe}")
        if isinstance(oe, AuthenticationError):
            logger.error("Authentication failed. Check API key.")
        elif isinstance(oe, RateLimitError):
            logger.error("Rate limit exceeded. Consider increasing retry attempts.")
        elif isinstance(oe, APIError):
            logger.error(f"API error: {oe}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in ask_with_images: {e}")
        raise
|
||||
|
||||
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(
            (OpenAIError, Exception, ValueError)
        ),  # Don't retry TokenLimitExceeded
        # NOTE(review): retry_if_exception_type matches subclasses, and the
        # tuple above contains Exception itself — so TokenLimitExceeded (and
        # every other exception) IS retried, contradicting the comment above.
        # Consider combining with tenacity's
        # retry_if_not_exception_type(TokenLimitExceeded).
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ) -> ChatCompletionMessage | None:
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response, or None when the API
            returns no choices.

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in TOOL_CHOICE_VALUES:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Check if the model supports images
            supports_images = self.model in MULTIMODAL_MODELS

            # Format messages (system messages, when given, are prepended)
            if system_msgs:
                system_msgs = self.format_messages(system_msgs, supports_images)
                formatted_messages = system_msgs + self.format_messages(messages, supports_images)
            else:
                formatted_messages = self.format_messages(messages, supports_images)

            # Validate the message sequence: a "tool" message is only legal
            # when a preceding assistant message announced its tool_call_id;
            # orphaned tool messages are dropped so the API does not reject
            # the whole request.
            valid_messages = []
            tool_calls_ids = set()  # every tool-call id announced so far

            for i, msg in enumerate(formatted_messages):
                if isinstance(msg, dict):
                    role = msg.get("role")
                    tool_call_id = msg.get("tool_call_id")
                else:
                    role = msg.role if hasattr(msg, "role") else None
                    tool_call_id = msg.tool_call_id if hasattr(msg, "tool_call_id") else None

                # Skip tool messages whose tool_call_id was never announced.
                if role == "tool" and tool_call_id:
                    if tool_call_id not in tool_calls_ids:
                        logger.warning(f"发现无效的工具消息 - tool_call_id '{tool_call_id}' 未找到关联的tool_calls,将跳过此消息")
                        continue

                # Record all ids carried by assistant messages' tool_calls.
                elif role == "assistant":
                    tool_calls = []
                    if isinstance(msg, dict) and "tool_calls" in msg:
                        tool_calls = msg.get("tool_calls", [])
                    elif hasattr(msg, "tool_calls") and msg.tool_calls:
                        tool_calls = msg.tool_calls

                    # Collect every tool-call id (dict or object form).
                    for call in tool_calls:
                        if isinstance(call, dict) and "id" in call:
                            tool_calls_ids.add(call["id"])
                        elif hasattr(call, "id"):
                            tool_calls_ids.add(call.id)

                # Keep the (now validated) message.
                valid_messages.append(msg)

            # Replace the original sequence with the validated one.
            formatted_messages = valid_messages

            # Calculate input token count
            input_tokens = self.count_message_tokens(formatted_messages)

            # If there are tools, calculate token count for tool descriptions
            tools_tokens = 0
            if tools:
                for tool in tools:
                    tools_tokens += self.count_tokens(str(tool))

            input_tokens += tools_tokens

            # Check if token limits are exceeded
            if not self.check_token_limit(input_tokens):
                error_message = self.get_limit_error_message(input_tokens)
                # Raise a special exception that won't be retried
                raise TokenLimitExceeded(error_message)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            params = {
                "model": self.model,
                "messages": formatted_messages,
                "tools": tools,
                "tool_choice": tool_choice,
                "timeout": timeout,
                **kwargs,
            }

            # Reasoning models take max_completion_tokens; other models take
            # max_tokens plus a sampling temperature.
            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            response: ChatCompletion = await self.client.chat.completions.create(
                **params, stream=False
            )

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)  # NOTE(review): leftover debug print; prefer logger
                # raise ValueError("Invalid or empty response from LLM")
                return None

            # Update token counts
            self.update_token_count(
                response.usage.prompt_tokens, response.usage.completion_tokens
            )

            return response.choices[0].message

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
|
||||
42
MeetSpot/app/logger.py
Normal file
42
MeetSpot/app/logger.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger as _logger
|
||||
|
||||
from app.config import PROJECT_ROOT
|
||||
|
||||
|
||||
_print_level = "INFO"  # last console level configured via define_log_level()


def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str | None = None):
    """Adjust the log level to above level.

    Reconfigures the shared loguru logger: console (stderr) output at
    *print_level* and a timestamped file under PROJECT_ROOT/logs at
    *logfile_level*. *name*, when given, prefixes the log file name.
    Returns the configured logger.
    """
    global _print_level
    _print_level = print_level

    current_date = datetime.now()
    formatted_date = current_date.strftime("%Y%m%d%H%M%S")
    log_name = (
        f"{name}_{formatted_date}" if name else formatted_date
    )  # name a log with prefix name

    _logger.remove()  # drop all previously-registered sinks before re-adding
    _logger.add(sys.stderr, level=print_level)
    _logger.add(PROJECT_ROOT / f"logs/{log_name}.log", level=logfile_level)
    return _logger
|
||||
|
||||
|
||||
# Module-level logger, configured once at import time with default levels.
logger = define_log_level()


if __name__ == "__main__":
    # Smoke test: emit one record at every level, plus an exception traceback.
    logger.info("Starting application")
    logger.debug("Debug message")
    logger.warning("Warning message")
    logger.error("Error message")
    logger.critical("Critical message")

    try:
        raise ValueError("Test error")
    except Exception as e:
        logger.exception(f"An error occurred: {e}")
|
||||
9
MeetSpot/app/models/__init__.py
Normal file
9
MeetSpot/app/models/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""数据模型包。"""
|
||||
|
||||
from app.db.database import Base # noqa: F401
|
||||
|
||||
# 导入所有模型以确保它们注册到Base.metadata
|
||||
from app.models.user import User # noqa: F401
|
||||
from app.models.room import GatheringRoom, RoomParticipant # noqa: F401
|
||||
from app.models.message import ChatMessage, VenueVote # noqa: F401
|
||||
|
||||
53
MeetSpot/app/models/message.py
Normal file
53
MeetSpot/app/models/message.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""聊天消息与投票记录模型。"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, String, Text, UniqueConstraint, func
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
def _generate_uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class VenueVote(Base):
    """A single user's like/dislike vote on a candidate venue in a room."""

    __tablename__ = "venue_votes"
    # One vote row per (room, venue, user); a repeat vote must update, not insert.
    __table_args__ = (UniqueConstraint("room_id", "venue_id", "user_id", name="uq_vote"),)

    id = Column(String(36), primary_key=True, default=_generate_uuid)  # UUID4 string
    room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
    venue_id = Column(String(100), nullable=False)  # external venue/POI id, not a FK
    user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
    vote_type = Column(String(20), nullable=False)  # "like" | "dislike" (see VoteCreate)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class ChatMessage(Base):
    """A chat message posted by a user inside a gathering room."""

    __tablename__ = "chat_messages"

    id = Column(String(36), primary_key=True, default=_generate_uuid)  # UUID4 string
    room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
    user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
    content = Column(Text, nullable=False)  # message body (non-empty, see ChatMessageCreate)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class ChatMessageCreate(BaseModel):
    """Request payload for posting a chat message."""

    content: str = Field(..., min_length=1, description="聊天内容")
|
||||
|
||||
|
||||
class VoteCreate(BaseModel):
    """Request payload for casting a venue vote."""

    venue_id: str
    vote_type: str = Field(..., pattern="^(like|dislike)$")  # only "like"/"dislike" accepted
|
||||
|
||||
|
||||
class VoteRead(BaseModel):
    """API representation of a single venue vote."""

    venue_id: str
    vote_type: str  # "like" or "dislike"
    user_id: str
    created_at: Optional[datetime] = None
|
||||
|
||||
60
MeetSpot/app/models/room.py
Normal file
60
MeetSpot/app/models/room.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""聚会房间相关的ORM与Pydantic模型。"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Column, DateTime, Float, ForeignKey, String, Text, UniqueConstraint, func
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
def _generate_uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class GatheringRoom(Base):
    """A planned gathering: host, schedule, venue preferences and the outcome."""

    __tablename__ = "gathering_rooms"

    id = Column(String(36), primary_key=True, default=_generate_uuid)  # UUID4 string
    name = Column(String(100), nullable=False)
    description = Column(Text, default="")
    host_user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    gathering_time = Column(DateTime(timezone=True))  # scheduled time; nullable
    status = Column(String(20), default="pending")  # lifecycle state (default "pending")
    venue_keywords = Column(String(100), default="咖啡馆")  # venue-type search keywords
    final_venue_json = Column(Text, nullable=True)  # presumably JSON of the chosen venue — verify writer
|
||||
|
||||
|
||||
class RoomParticipant(Base):
    """Membership row linking a user to a room, with their starting location."""

    __tablename__ = "room_participants"
    # A user can join a given room at most once.
    __table_args__ = (UniqueConstraint("room_id", "user_id", name="uq_room_user"),)

    id = Column(String(36), primary_key=True, default=_generate_uuid)  # UUID4 string
    room_id = Column(String(36), ForeignKey("gathering_rooms.id"), nullable=False)
    user_id = Column(String(36), ForeignKey("users.id"), nullable=False)
    location_name = Column(String(200))  # human-readable starting location
    location_lat = Column(Float)
    location_lng = Column(Float)
    joined_at = Column(DateTime(timezone=True), server_default=func.now())
    role = Column(String(20), default="member")  # "member" by default; presumably "host" for the creator — verify
|
||||
|
||||
|
||||
class GatheringRoomCreate(BaseModel):
    """Request payload for creating a gathering room."""

    name: str = Field(..., description="聚会名称")
    description: str = Field("", description="聚会描述")
    gathering_time: Optional[datetime] = Field(
        None, description="聚会时间,ISO 字符串"
    )
    venue_keywords: str = Field("咖啡馆", description="场所类型关键词")
|
||||
|
||||
|
||||
class RoomParticipantRead(BaseModel):
    """API representation of one room participant."""

    user_id: str
    nickname: str  # denormalized from the User record for display
    location_name: Optional[str] = None
    location_coords: Optional[Tuple[float, float]] = None  # presumably (lat, lng) — confirm against caller
    role: str
|
||||
|
||||
44
MeetSpot/app/models/user.py
Normal file
44
MeetSpot/app/models/user.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""用户相关SQLAlchemy模型与Pydantic模式。"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Column, DateTime, String, func
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
def _generate_uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class User(Base):
    """Registered user, uniquely identified by phone number."""

    __tablename__ = "users"

    id = Column(String(36), primary_key=True, default=_generate_uuid)  # UUID4 string
    phone = Column(String(20), unique=True, nullable=False)  # unique login identifier
    nickname = Column(String(50), nullable=False)
    avatar_url = Column(String(255), default="")
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    last_login = Column(DateTime(timezone=True))  # presumably set by the auth flow — verify; NULL until then
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Request payload for registering a user."""

    phone: str = Field(..., description="手机号")
    nickname: Optional[str] = Field(None, description="昵称,可选")
    avatar_url: Optional[str] = Field("", description="头像URL,可选")
|
||||
|
||||
|
||||
class UserRead(BaseModel):
    """API representation of a user, loadable directly from the ORM row."""

    id: str
    phone: str
    nickname: str
    avatar_url: str = ""
    created_at: datetime
    last_login: Optional[datetime] = None

    class Config:
        # Allow construction from SQLAlchemy objects via attribute access.
        from_attributes = True
|
||||
|
||||
214
MeetSpot/app/schema.py
Normal file
214
MeetSpot/app/schema.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Role(str, Enum):
    """Message role options"""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


# Tuple of the raw role strings, plus a Literal alias used in annotations.
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore
|
||||
|
||||
|
||||
class ToolChoice(str, Enum):
    """Tool choice options"""

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"


# Tuple of the raw tool-choice strings, plus a Literal alias for annotations.
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]  # type: ignore
|
||||
|
||||
|
||||
class AgentState(str, Enum):
    """Agent execution states"""

    IDLE = "IDLE"        # not yet started / between runs
    RUNNING = "RUNNING"  # actively stepping
    FINISHED = "FINISHED"
    ERROR = "ERROR"
|
||||
|
||||
|
||||
class Function(BaseModel):
    """Name and raw string arguments of a function call."""

    name: str
    arguments: str  # arguments as a raw string (typically JSON-encoded)
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str
    type: str = "function"  # only function-type calls are modeled here
    function: Function
|
||||
|
||||
|
||||
class Message(BaseModel):
    """Represents a chat message in the conversation"""

    role: ROLE_TYPE = Field(...)  # type: ignore
    content: Optional[str] = Field(default=None)
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)           # tool name (tool messages)
    tool_call_id: Optional[str] = Field(default=None)   # id of the call a tool message answers
    base64_image: Optional[str] = Field(default=None)   # optional attached image

    def __add__(self, other) -> List["Message"]:
        """Support `Message + list` and `Message + Message`, yielding a list."""
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """Support `list + Message`."""
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert message to dictionary format"""
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        # tool_calls are only emitted on assistant messages.
        if self.tool_calls is not None and self.role == Role.ASSISTANT:
            message["tool_calls"] = [
                # NOTE(review): .dict() is a pydantic v1 API (model_dump in v2)
                # — confirm the pinned pydantic version.
                tool_call.dict() if hasattr(tool_call, "dict") else tool_call
                for tool_call in self.tool_calls
            ]
        # name/tool_call_id are only meaningful on tool messages.
        if self.name is not None and self.role == Role.TOOL:
            message["name"] = self.name
        if self.tool_call_id is not None and self.role == Role.TOOL:
            message["tool_call_id"] = self.tool_call_id
        # base64_image is deliberately omitted: it is not part of the
        # OpenAI API message format.
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message"""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls, content: str, name: str, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a tool message

        Args:
            content: The content/result of the tool execution
            name: The name of the tool that was executed
            tool_call_id: The ID of the tool call this message is responding to
            base64_image: Optional base64 encoded image

        Raises:
            ValueError: If tool_call_id or name is missing.
        """
        if not tool_call_id:
            raise ValueError("tool_call_id is required for tool messages")
        if not name:
            raise ValueError("name is required for tool messages")

        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ):
        """Create ToolCallsMessage from raw tool calls.

        Args:
            tool_calls: Raw tool calls from LLM
            content: Optional message content
            base64_image: Optional base64 encoded image
        """
        # Normalize tool calls into plain dicts of the expected shape.
        formatted_calls = []
        for call in tool_calls:
            if hasattr(call, "id") and hasattr(call, "function"):
                func_data = call.function
                if hasattr(func_data, "model_dump"):
                    func_dict = func_data.model_dump()
                else:
                    func_dict = {"name": func_data.name, "arguments": func_data.arguments}

                formatted_call = {
                    "id": call.id,
                    "type": "function",
                    "function": func_dict
                }
                formatted_calls.append(formatted_call)
            else:
                # Already dict-shaped; use as-is.
                formatted_calls.append(call)

        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
|
||||
|
||||
|
||||
class Memory(BaseModel):
    """A bounded, ordered store of conversation Messages."""

    messages: List[Message] = Field(default_factory=list)
    max_messages: int = Field(default=100)

    def add_message(self, message: Message) -> None:
        """Append one message, keeping only the newest `max_messages` entries."""
        self.messages.append(message)
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages :]

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages at once (no capacity trimming here)."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Drop the entire history."""
        del self.messages[:]

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return the last `n` messages (fewer if the history is shorter)."""
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Serialize every stored message via Message.to_dict()."""
        return [m.to_dict() for m in self.messages]
|
||||
9
MeetSpot/app/tool/__init__.py
Normal file
9
MeetSpot/app/tool/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.tool.base import BaseTool
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.tool.tool_collection import ToolCollection
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"CafeRecommender",
|
||||
"ToolCollection",
|
||||
]
|
||||
101
MeetSpot/app/tool/base.py
Normal file
101
MeetSpot/app/tool/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel):
    """Abstract base for tools: a name, a description, and an optional
    JSON-schema-style `parameters` dict used for LLM function calling."""

    name: str
    description: str
    parameters: Optional[dict] = None

    class Config:
        # Allow non-pydantic field types on subclasses.
        arbitrary_types_allowed = True

    async def __call__(self, **kwargs) -> Any:
        """Execute the tool with given parameters."""
        return await self.execute(**kwargs)

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Execute the tool with given parameters."""

    def to_param(self) -> Dict:
        """Convert tool to the OpenAI function-call parameter format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """Represents the result of a tool execution."""

    output: Any = Field(default=None)             # successful payload, if any
    error: Optional[str] = Field(default=None)    # error message, if the tool failed
    base64_image: Optional[str] = Field(default=None)
    system: Optional[str] = Field(default=None)   # out-of-band system note

    class Config:
        arbitrary_types_allowed = True

    def __bool__(self):
        # Truthy when any field is set.
        # NOTE(review): __fields__ and .dict() (below) are pydantic v1 APIs,
        # deprecated under v2 — confirm the pinned pydantic version.
        return any(getattr(self, field) for field in self.__fields__)

    def __add__(self, other: "ToolResult"):
        """Merge two results: concatenating fields are joined, non-concatenating
        fields (base64_image) must not be set on both sides."""
        def combine_fields(
            field: Optional[str], other_field: Optional[str], concatenate: bool = True
        ):
            if field and other_field:
                if concatenate:
                    return field + other_field
                raise ValueError("Cannot combine tool results")
            return field or other_field

        return ToolResult(
            output=combine_fields(self.output, other.output),
            error=combine_fields(self.error, other.error),
            base64_image=combine_fields(self.base64_image, other.base64_image, False),
            system=combine_fields(self.system, other.system),
        )

    def __str__(self):
        return f"Error: {self.error}" if self.error else self.output

    def replace(self, **kwargs):
        """Returns a new ToolResult with the given fields replaced."""
        # return self.copy(update=kwargs)
        return type(self)(**{**self.dict(), **kwargs})
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
    """A ToolResult that can be rendered as a CLI output.

    Marker subclass: adds no fields or behavior of its own.
    """
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
    """A ToolResult that represents a failure.

    Marker subclass: construct with `error=` to describe the failure.
    """
|
||||
|
||||
|
||||
# Helper constructors, attached to BaseTool below.
def _success_response(data) -> ToolResult:
    """Wrap *data* in a successful ToolResult (non-strings are JSON-encoded)."""
    import json  # local import keeps the module's import list untouched
    if isinstance(data, str):
        text = data
    else:
        text = json.dumps(data, ensure_ascii=False, indent=2)
    return ToolResult(output=text)


def _fail_response(msg: str) -> ToolResult:
    """Wrap an error message in a failed ToolResult."""
    return ToolResult(error=msg)


# Expose the helpers as BaseTool.success_response / BaseTool.fail_response so
# concrete tools can build results without importing ToolResult directly.
BaseTool.success_response = staticmethod(_success_response)
BaseTool.fail_response = staticmethod(_fail_response)
|
||||
158
MeetSpot/app/tool/file_operators.py
Normal file
158
MeetSpot/app/tool/file_operators.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""File operation interfaces and implementations for local and sandbox environments."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Protocol, Tuple, Union, runtime_checkable
|
||||
|
||||
from app.config import SandboxSettings
|
||||
from app.exceptions import ToolError
|
||||
from app.sandbox.client import SANDBOX_CLIENT
|
||||
|
||||
|
||||
PathLike = Union[str, Path]
|
||||
|
||||
|
||||
# runtime_checkable allows isinstance() checks against this Protocol.
@runtime_checkable
class FileOperator(Protocol):
    """Interface for file operations in different environments."""

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file."""
        ...

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file."""
        ...

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        ...

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        ...

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command and return (return_code, stdout, stderr)."""
        ...
|
||||
|
||||
|
||||
class LocalFileOperator(FileOperator):
    """File operations implementation for local filesystem."""

    encoding: str = "utf-8"  # text encoding for every read/write

    async def read_file(self, path: PathLike) -> str:
        """Read content from a local file."""
        try:
            return Path(path).read_text(encoding=self.encoding)
        except Exception as e:
            # `from None` drops the original traceback; only the message survives.
            raise ToolError(f"Failed to read {path}: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a local file."""
        try:
            Path(path).write_text(content, encoding=self.encoding)
        except Exception as e:
            raise ToolError(f"Failed to write to {path}: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        return Path(path).is_dir()

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        return Path(path).exists()

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command locally.

        Returns (return_code, stdout, stderr). Raises TimeoutError when the
        command does not finish within *timeout* seconds; the process is
        killed before raising.
        """
        process = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(), timeout=timeout
            )
            return (
                process.returncode or 0,
                stdout.decode(),
                stderr.decode(),
            )
        except asyncio.TimeoutError as exc:
            try:
                process.kill()
            except ProcessLookupError:
                # Process already exited between the timeout and the kill.
                pass
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds"
            ) from exc
|
||||
|
||||
|
||||
class SandboxFileOperator(FileOperator):
    """File operations implementation for sandbox environment."""

    def __init__(self):
        self.sandbox_client = SANDBOX_CLIENT

    async def _ensure_sandbox_initialized(self):
        """Ensure sandbox is initialized (lazily created on first use)."""
        if not self.sandbox_client.sandbox:
            await self.sandbox_client.create(config=SandboxSettings())

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            return await self.sandbox_client.read_file(str(path))
        except Exception as e:
            raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            await self.sandbox_client.write_file(str(path), content)
        except Exception as e:
            raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory in sandbox."""
        await self._ensure_sandbox_initialized()
        # NOTE(review): `path` is interpolated into a shell command unquoted —
        # paths with spaces or shell metacharacters will break or inject;
        # consider shlex.quote.
        result = await self.sandbox_client.run_command(
            f"test -d {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists in sandbox."""
        await self._ensure_sandbox_initialized()
        # NOTE(review): same unquoted-path caveat as is_directory above.
        result = await self.sandbox_client.run_command(
            f"test -e {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a command in sandbox environment."""
        await self._ensure_sandbox_initialized()
        try:
            stdout = await self.sandbox_client.run_command(
                cmd, timeout=int(timeout) if timeout else None
            )
            return (
                0,  # Always return 0 since we don't have explicit return code from sandbox
                stdout,
                "",  # No stderr capture in the current sandbox implementation
            )
        except TimeoutError as exc:
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds in sandbox"
            ) from exc
        except Exception as exc:
            # Errors are reported via the stderr slot with return code 1.
            return 1, "", f"Error executing command in sandbox: {str(exc)}"
|
||||
3615
MeetSpot/app/tool/meetspot_recommender.py
Normal file
3615
MeetSpot/app/tool/meetspot_recommender.py
Normal file
File diff suppressed because it is too large
Load Diff
64
MeetSpot/app/tool/tool_collection.py
Normal file
64
MeetSpot/app/tool/tool_collection.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from app.exceptions import ToolError
|
||||
from app.tool.base import BaseTool, ToolFailure, ToolResult
|
||||
|
||||
|
||||
class ToolCollection:
    """A collection of defined tools."""

    # NOTE(review): ToolCollection is not a pydantic model, so this inner
    # Config class has no effect — looks copied from a BaseModel.
    class Config:
        arbitrary_types_allowed = True

    def __init__(self, *tools: BaseTool):
        self.tools = tools
        # Index by tool name for O(1) lookup in execute()/get_tool().
        self.tool_map = {tool.name: tool for tool in tools}

    def __iter__(self):
        return iter(self.tools)

    def to_params(self) -> List[Dict[str, Any]]:
        """Return every tool in OpenAI function-calling parameter format."""
        return [tool.to_param() for tool in self.tools]

    async def execute(self, name: str, tool_input: Union[str, dict]) -> ToolResult:
        """Execute a tool by name with given input."""
        tool = self.get_tool(name)
        if not tool:
            return ToolResult(error=f"Tool '{name}' not found")

        # Accept JSON strings as well as dicts for the tool input.
        if isinstance(tool_input, str):
            try:
                tool_input = json.loads(tool_input)
            except json.JSONDecodeError:
                return ToolResult(error=f"Invalid tool input format: {tool_input}")

        result = await tool(**tool_input)
        return result

    async def execute_all(self) -> List[ToolResult]:
        """Execute all tools in the collection sequentially."""
        results = []
        for tool in self.tools:
            try:
                result = await tool()
                results.append(result)
            except ToolError as e:
                # A failing tool does not abort the batch.
                results.append(ToolFailure(error=e.message))
        return results

    def get_tool(self, name: str) -> "BaseTool | None":
        """Look up a tool by name; returns None when absent."""
        return self.tool_map.get(name)

    def add_tool(self, tool: BaseTool):
        """Append one tool and index it by name; returns self for chaining."""
        self.tools += (tool,)
        self.tool_map[tool.name] = tool
        return self

    def add_tools(self, *tools: BaseTool):
        """Append several tools; returns self for chaining."""
        for tool in tools:
            self.add_tool(tool)
        return self
|
||||
101
MeetSpot/app/tool/web_search.py
Normal file
101
MeetSpot/app/tool/web_search.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import config
|
||||
from app.tool.base import BaseTool
|
||||
from app.tool.search import (
|
||||
BaiduSearchEngine,
|
||||
BingSearchEngine,
|
||||
DuckDuckGoSearchEngine,
|
||||
GoogleSearchEngine,
|
||||
WebSearchEngine,
|
||||
)
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
    """Web search tool with a configurable primary engine and automatic fallback."""

    name: str = "web_search"
    description: str = """Perform a web search and return a list of relevant links.
This function attempts to use the primary search engine API to get up-to-date results.
If an error occurs, it falls back to an alternative search engine."""
    parameters: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "(required) The search query to submit to the search engine.",
            },
            "num_results": {
                "type": "integer",
                "description": "(optional) The number of search results to return. Default is 10.",
                "default": 10,
            },
        },
        "required": ["query"],
    }
    # Engine registry; instances are created eagerly at class-definition time.
    _search_engine: dict[str, WebSearchEngine] = {
        "google": GoogleSearchEngine(),
        "baidu": BaiduSearchEngine(),
        "duckduckgo": DuckDuckGoSearchEngine(),
        "bing": BingSearchEngine(),
    }

    async def execute(self, query: str, num_results: int = 10) -> List[str]:
        """
        Execute a Web search and return a list of URLs.

        Engines are tried in preference order; each engine gets up to three
        retried attempts (see ``_perform_search_with_engine``) before the
        next engine is tried. Returns an empty list if every engine fails.

        Args:
            query (str): The search query to submit to the search engine.
            num_results (int, optional): The number of search results to return. Default is 10.

        Returns:
            List[str]: A list of URLs matching the search query.
        """
        engine_order = self._get_engine_order()
        for engine_name in engine_order:
            engine = self._search_engine[engine_name]
            try:
                links = await self._perform_search_with_engine(
                    engine, query, num_results
                )
                if links:
                    return links
            except Exception as e:
                # Best-effort fallback: report the failure and move on to
                # the next engine instead of propagating.
                print(f"Search engine '{engine_name}' failed with error: {e}")
        return []

    def _get_engine_order(self) -> List[str]:
        """
        Determines the order in which to try search engines.
        Preferred engine is first (based on configuration), followed by the remaining engines.

        Returns:
            List[str]: Ordered list of search engine names.
        """
        # Default to Google unless configuration names a known engine.
        preferred = "google"
        if config.search_config and config.search_config.engine:
            preferred = config.search_config.engine.lower()

        engine_order = []
        if preferred in self._search_engine:
            engine_order.append(preferred)
        for key in self._search_engine:
            if key not in engine_order:
                engine_order.append(key)
        return engine_order

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
    )
    async def _perform_search_with_engine(
        self,
        engine: WebSearchEngine,
        query: str,
        num_results: int,
    ) -> List[str]:
        """Run one blocking engine search in the default executor.

        Wrapped with tenacity so transient engine errors are retried up to
        three times with exponential backoff before bubbling up.
        """
        # Fix: asyncio.get_event_loop() is deprecated inside coroutines
        # since Python 3.10 (this project targets 3.11-3.13);
        # get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: list(engine.perform_search(query, num_results=num_results))
        )
|
||||
Reference in New Issue
Block a user