from enum import Enum
from typing import Any, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class Role(str, Enum):
    """Message role options"""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


# Literal type built from the enum values so pydantic validates roles as plain strings.
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore


class ToolChoice(str, Enum):
    """Tool choice options"""

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"


TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]  # type: ignore


class AgentState(str, Enum):
    """Agent execution states"""

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"


class Function(BaseModel):
    """A called function: its name and its arguments as a raw string."""

    name: str
    arguments: str


class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str
    type: str = "function"
    function: Function


def _dump_model(obj: Any) -> Any:
    """Return a plain-dict dump of a pydantic model, or ``obj`` itself otherwise.

    Prefers pydantic v2's ``model_dump()`` and falls back to v1's ``dict()``
    (deprecated in v2), so the output is the same on either version.
    """
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    if hasattr(obj, "dict"):
        return obj.dict()
    return obj


class Message(BaseModel):
    """Represents a chat message in the conversation"""

    role: ROLE_TYPE = Field(...)  # type: ignore
    content: Optional[str] = Field(default=None)
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Returns:
            A new list with this message first, followed by ``other`` (its
            elements when it is a list).

        Raises:
            TypeError: If ``other`` is neither a list nor a Message.
        """
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``.

        Returns:
            A new list with the elements of ``other`` followed by this message.

        Raises:
            TypeError: If ``other`` is not a list.
        """
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert the message to a dict suitable for an API call.

        Only fields that are set are included. ``tool_calls`` is emitted only for
        assistant messages; ``name`` and ``tool_call_id`` only for tool messages.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None and self.role == Role.ASSISTANT:
            message["tool_calls"] = [_dump_model(tool_call) for tool_call in self.tool_calls]
        if self.name is not None and self.role == Role.TOOL:
            message["name"] = self.name
        if self.tool_call_id is not None and self.role == Role.TOOL:
            message["tool_call_id"] = self.tool_call_id
        # base64_image is deliberately left out: it is not part of the OpenAI API
        # message format.
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message"""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls, content: str, name: str, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a tool message

        Args:
            content: The content/result of the tool execution
            name: The name of the tool that was executed
            tool_call_id: The ID of the tool call this message is responding to
            base64_image: Optional base64 encoded image

        Raises:
            ValueError: If ``tool_call_id`` or ``name`` is empty.
        """
        if not tool_call_id:
            raise ValueError("tool_call_id is required for tool messages")
        if not name:
            raise ValueError("name is required for tool messages")

        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create an assistant Message from raw tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. An entry that has ``id`` and
                ``function`` attributes is normalized into a
                ``{"id", "type": "function", "function": {...}}`` dict; any other
                entry is passed through unchanged (expected to already be a dict
                in that shape).
            content: Optional message content. NOTE(review): the annotation allows
                a list of strings, but ``Message.content`` is ``Optional[str]``, so
                a list will presumably fail validation. Confirm the intended type.
            base64_image: Optional base64 encoded image
            **kwargs: Extra fields forwarded to the Message constructor.

        Returns:
            A Message with role ``assistant`` carrying the normalized tool calls.
        """
        # Normalize tool_calls into a list of plain dicts in tool-call shape.
        formatted_calls = []
        for call in tool_calls:
            if hasattr(call, "id") and hasattr(call, "function"):
                func_data = call.function
                if hasattr(func_data, "model_dump"):
                    func_dict = func_data.model_dump()
                else:
                    func_dict = {"name": func_data.name, "arguments": func_data.arguments}

                formatted_calls.append(
                    {
                        "id": call.id,
                        "type": "function",
                        "function": func_dict,
                    }
                )
            else:
                # Already a dict: use it as-is.
                formatted_calls.append(call)

        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )


class Memory(BaseModel):
    """Ordered message history capped at ``max_messages`` (oldest dropped first)."""

    messages: List[Message] = Field(default_factory=list)
    max_messages: int = Field(default=100)

    def _trim(self) -> None:
        """Drop the oldest messages so that at most ``max_messages`` remain.

        Uses an explicit overflow count rather than ``[-max_messages:]`` because
        ``[-0:]`` is the whole list, which would make a limit of 0 keep
        everything. A limit of 0 (or less) therefore keeps no messages.
        """
        overflow = len(self.messages) - self.max_messages
        if overflow > 0:
            self.messages = self.messages[overflow:]

    def add_message(self, message: Message) -> None:
        """Append a message, then enforce the ``max_messages`` limit."""
        self.messages.append(message)
        self._trim()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order, then enforce the ``max_messages`` limit."""
        self.messages.extend(messages)
        # Apply the same cap as add_message so bulk adds cannot bypass the limit.
        self._trim()

    def clear(self) -> None:
        """Clear all messages"""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return up to ``n`` of the most recent messages, oldest first.

        Returns an empty list when ``n <= 0``. Plain ``[-n:]`` slicing would
        return every message for ``n == 0``.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert messages to list of dicts"""
        return [msg.to_dict() for msg in self.messages]