215 lines
6.7 KiB
Python
215 lines
6.7 KiB
Python
from enum import Enum
|
||
from typing import Any, List, Literal, Optional, Union
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
|
||
class Role(str, Enum):
    """Roles under which a chat message can be sent.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so ``Role.USER == "user"`` holds.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
|
||
|
||
|
||
# String values of every Role member, in declaration order.
ROLE_VALUES = tuple(map(lambda member: member.value, Role))
# Literal type over those values: subscripting Literal with a tuple expands it
# into one Literal argument per element. Static checkers cannot evaluate this
# dynamically-built form, hence the type: ignore.
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore
|
||
|
||
|
||
class ToolChoice(str, Enum):
    """Allowed values for the tool-choice setting.

    Members are ``str`` subclasses, so each compares equal to its raw value.
    """

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"
|
||
|
||
|
||
# String values of every ToolChoice member, in declaration order.
TOOL_CHOICE_VALUES = tuple(map(lambda choice: choice.value, ToolChoice))
# Literal type over those values (a tuple subscript expands into one Literal
# argument per element). The dynamic construction is opaque to static type
# checkers, hence the type: ignore.
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]  # type: ignore
|
||
|
||
|
||
class AgentState(str, Enum):
    """Execution states of an agent.

    Each member's value is its own upper-case name, and thanks to the ``str``
    mixin it compares equal to that string.
    """

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
|
||
|
||
|
||
class Function(BaseModel):
    """Name and argument payload of a function the model asked to invoke."""

    # Name of the function/tool to call (required).
    name: str = Field(...)
    # Arguments kept as a raw string; presumably JSON produced by the LLM,
    # but no parsing or validation happens here.
    arguments: str = Field(...)
|
||
|
||
|
||
class ToolCall(BaseModel):
    """A single tool/function invocation attached to a chat message."""

    # Identifier of this call (required); a tool-role reply refers back to it.
    id: str = Field(...)
    # Kind of call; defaults to "function".
    type: str = Field(default="function")
    # The function name and its raw arguments (required).
    function: Function = Field(...)
|
||
|
||
|
||
class Message(BaseModel):
    """A single chat message in the conversation.

    Instances are normally created through the factory classmethods
    (``user_message``, ``system_message``, ``assistant_message``,
    ``tool_message``, ``from_tool_calls``) and converted to plain dicts for an
    API call with :meth:`to_dict`.
    """

    # Sender role; restricted to the string values of the Role enum.
    role: ROLE_TYPE = Field(...)  # type: ignore
    # Text payload; may be None (assistant_message defaults it to None).
    content: Optional[str] = Field(default=None)
    # Tool invocations; only emitted by to_dict() for assistant messages.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # Name of the executed tool; only emitted by to_dict() for tool messages.
    name: Optional[str] = Field(default=None)
    # ID of the tool call a tool message answers; only emitted for tool messages.
    tool_call_id: Optional[str] = Field(default=None)
    # Optional base64-encoded image; stored on the model but never emitted by to_dict().
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Returns a new list with this message first. For any other operand type,
        ``NotImplemented`` is returned so Python can try the reflected operation
        and otherwise raise ``TypeError`` itself, with the message
        "unsupported operand type(s) for +: 'Message' and '<type>'".
        """
        if isinstance(other, list):
            return [self] + other
        if isinstance(other, Message):
            return [self, other]
        return NotImplemented

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``; returns a new list with this message last.

        Other operand types return ``NotImplemented``, so Python raises the
        standard ``TypeError``.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def to_dict(self) -> dict:
        """Convert the message to a plain dict for an LLM chat API call.

        Only fields that are set and relevant to the role are emitted:
        ``content`` whenever it is not None, ``tool_calls`` for assistant
        messages, and ``name``/``tool_call_id`` for tool messages.
        ``base64_image`` is intentionally omitted.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None and self.role == Role.ASSISTANT:
            # Prefer the pydantic v2 API; ``.dict()`` is deprecated there and is
            # only a fallback for v1-style objects. Anything else passes through.
            message["tool_calls"] = [
                call.model_dump()
                if hasattr(call, "model_dump")
                else call.dict()
                if hasattr(call, "dict")
                else call
                for call in self.tool_calls
            ]
        if self.role == Role.TOOL:
            if self.name is not None:
                message["name"] = self.name
            if self.tool_call_id is not None:
                message["tool_call_id"] = self.tool_call_id
        # base64_image is not part of the OpenAI API message format, so it is
        # never included in the API payload.
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message with optional base64-encoded image."""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message; ``content`` may be None."""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls, content: str, name: str, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a tool message.

        Args:
            content: The content/result of the tool execution.
            name: The name of the tool that was executed.
            tool_call_id: The ID of the tool call this message is responding to.
            base64_image: Optional base64-encoded image.

        Returns:
            A Message with role ``tool``.

        Raises:
            ValueError: If ``tool_call_id`` or ``name`` is empty or None.
        """
        if not tool_call_id:
            raise ValueError("tool_call_id is required for tool messages")
        if not name:
            raise ValueError("name is required for tool messages")

        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create an assistant message from raw LLM tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. Objects exposing both ``id``
                and ``function`` attributes are normalized to dicts of the form
                ``{"id", "type": "function", "function": {...}}``; anything else
                (e.g. an already-formatted dict) is passed through unchanged and
                left to field validation.
            content: Optional message content.
            base64_image: Optional base64-encoded image.
            **kwargs: Extra fields forwarded to the Message constructor.

        Returns:
            A Message with role ``assistant``.
        """
        # NOTE(review): ``content`` is annotated to accept List[str], but the
        # ``content`` field is Optional[str], so a list will presumably fail
        # validation — confirm the intended handling.
        formatted_calls = []
        for call in tool_calls:
            if not (hasattr(call, "id") and hasattr(call, "function")):
                # Already in dict form: use as-is.
                formatted_calls.append(call)
                continue

            func_data = call.function
            if hasattr(func_data, "model_dump"):
                func_dict = func_data.model_dump()
            else:
                func_dict = {"name": func_data.name, "arguments": func_data.arguments}

            formatted_calls.append(
                {"id": call.id, "type": "function", "function": func_dict}
            )

        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
|
||
|
||
|
||
class Memory(BaseModel):
    """Ordered store of conversation messages, capped at ``max_messages``."""

    # Messages in insertion order (oldest first).
    messages: List[Message] = Field(default_factory=list)
    # Maximum number of messages retained; the oldest are dropped beyond this.
    max_messages: int = Field(default=100)

    def _trim_to_limit(self) -> None:
        """Drop the oldest messages in place so at most ``max_messages`` remain."""
        # Compute the overflow explicitly: slicing with ``[-max_messages:]``
        # would keep everything when max_messages is 0, because ``-0 == 0``.
        excess = len(self.messages) - max(self.max_messages, 0)
        if excess > 0:
            del self.messages[:excess]

    def add_message(self, message: Message) -> None:
        """Append one message, then enforce the ``max_messages`` limit."""
        self.messages.append(message)
        self._trim_to_limit()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order, then enforce the ``max_messages`` limit."""
        self.messages.extend(messages)
        self._trim_to_limit()

    def clear(self) -> None:
        """Remove all stored messages."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return a new list with the ``n`` most recent messages, oldest first.

        Returns an empty list when ``n <= 0``. A plain ``[-n:]`` slice would
        return every message for ``n == 0``, because ``-0 == 0``.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert all stored messages to dicts via ``Message.to_dict``."""
        return [msg.to_dict() for msg in self.messages]
|