first commit
This commit is contained in:
214
MeetSpot/app/schema.py
Normal file
214
MeetSpot/app/schema.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""Message role options"""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
ROLE_VALUES = tuple(role.value for role in Role)
|
||||
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
|
||||
|
||||
|
||||
class ToolChoice(str, Enum):
|
||||
"""Tool choice options"""
|
||||
|
||||
NONE = "none"
|
||||
AUTO = "auto"
|
||||
REQUIRED = "required"
|
||||
|
||||
|
||||
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
|
||||
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
|
||||
|
||||
|
||||
class AgentState(str, Enum):
|
||||
"""Agent execution states"""
|
||||
|
||||
IDLE = "IDLE"
|
||||
RUNNING = "RUNNING"
|
||||
FINISHED = "FINISHED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Represents a tool/function call in a message"""
|
||||
|
||||
id: str
|
||||
type: str = "function"
|
||||
function: Function
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Represents a chat message in the conversation"""
|
||||
|
||||
role: ROLE_TYPE = Field(...) # type: ignore
|
||||
content: Optional[str] = Field(default=None)
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||
name: Optional[str] = Field(default=None)
|
||||
tool_call_id: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
|
||||
def __add__(self, other) -> List["Message"]:
|
||||
"""支持 Message + list 或 Message + Message 的操作"""
|
||||
if isinstance(other, list):
|
||||
return [self] + other
|
||||
elif isinstance(other, Message):
|
||||
return [self, other]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
|
||||
)
|
||||
|
||||
def __radd__(self, other) -> List["Message"]:
|
||||
"""支持 list + Message 的操作"""
|
||||
if isinstance(other, list):
|
||||
return other + [self]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert message to dictionary format"""
|
||||
message = {"role": self.role}
|
||||
if self.content is not None:
|
||||
message["content"] = self.content
|
||||
if self.tool_calls is not None and self.role == Role.ASSISTANT:
|
||||
message["tool_calls"] = [
|
||||
tool_call.dict() if hasattr(tool_call, "dict") else tool_call
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
if self.name is not None and self.role == Role.TOOL:
|
||||
message["name"] = self.name
|
||||
if self.tool_call_id is not None and self.role == Role.TOOL:
|
||||
message["tool_call_id"] = self.tool_call_id
|
||||
# 不要在API调用中包含base64_image,这不是OpenAI API消息格式的一部分
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def user_message(
|
||||
cls, content: str, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create a user message"""
|
||||
return cls(role=Role.USER, content=content, base64_image=base64_image)
|
||||
|
||||
@classmethod
|
||||
def system_message(cls, content: str) -> "Message":
|
||||
"""Create a system message"""
|
||||
return cls(role=Role.SYSTEM, content=content)
|
||||
|
||||
@classmethod
|
||||
def assistant_message(
|
||||
cls, content: Optional[str] = None, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create an assistant message"""
|
||||
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
|
||||
|
||||
@classmethod
|
||||
def tool_message(
|
||||
cls, content: str, name: str, tool_call_id: str, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create a tool message
|
||||
|
||||
Args:
|
||||
content: The content/result of the tool execution
|
||||
name: The name of the tool that was executed
|
||||
tool_call_id: The ID of the tool call this message is responding to
|
||||
base64_image: Optional base64 encoded image
|
||||
"""
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required for tool messages")
|
||||
if not name:
|
||||
raise ValueError("name is required for tool messages")
|
||||
|
||||
return cls(
|
||||
role=Role.TOOL,
|
||||
content=content,
|
||||
name=name,
|
||||
tool_call_id=tool_call_id,
|
||||
base64_image=base64_image,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_calls(
|
||||
cls,
|
||||
tool_calls: List[Any],
|
||||
content: Union[str, List[str]] = "",
|
||||
base64_image: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create ToolCallsMessage from raw tool calls.
|
||||
|
||||
Args:
|
||||
tool_calls: Raw tool calls from LLM
|
||||
content: Optional message content
|
||||
base64_image: Optional base64 encoded image
|
||||
"""
|
||||
# 确保tool_calls是正确格式的对象列表
|
||||
formatted_calls = []
|
||||
for call in tool_calls:
|
||||
if hasattr(call, "id") and hasattr(call, "function"):
|
||||
func_data = call.function
|
||||
if hasattr(func_data, "model_dump"):
|
||||
func_dict = func_data.model_dump()
|
||||
else:
|
||||
func_dict = {"name": func_data.name, "arguments": func_data.arguments}
|
||||
|
||||
formatted_call = {
|
||||
"id": call.id,
|
||||
"type": "function",
|
||||
"function": func_dict
|
||||
}
|
||||
formatted_calls.append(formatted_call)
|
||||
else:
|
||||
# 如果已经是字典格式,直接使用
|
||||
formatted_calls.append(call)
|
||||
|
||||
return cls(
|
||||
role=Role.ASSISTANT,
|
||||
content=content,
|
||||
tool_calls=formatted_calls,
|
||||
base64_image=base64_image,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
max_messages: int = Field(default=100)
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""Add a message to memory"""
|
||||
self.messages.append(message)
|
||||
# Optional: Implement message limit
|
||||
if len(self.messages) > self.max_messages:
|
||||
self.messages = self.messages[-self.max_messages :]
|
||||
|
||||
def add_messages(self, messages: List[Message]) -> None:
|
||||
"""Add multiple messages to memory"""
|
||||
self.messages.extend(messages)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages"""
|
||||
self.messages.clear()
|
||||
|
||||
def get_recent_messages(self, n: int) -> List[Message]:
|
||||
"""Get n most recent messages"""
|
||||
return self.messages[-n:]
|
||||
|
||||
def to_dict_list(self) -> List[dict]:
|
||||
"""Convert messages to list of dicts"""
|
||||
return [msg.to_dict() for msg in self.messages]
|
||||
Reference in New Issue
Block a user