first commit
This commit is contained in:
101
MeetSpot/app/tool/base.py
Normal file
101
MeetSpot/app/tool/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Execute the tool with given parameters."""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> Any:
|
||||
"""Execute the tool with given parameters."""
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
"""Convert tool to function call format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: Any = Field(default=None)
|
||||
error: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
system: Optional[str] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field) for field in self.__fields__)
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"Error: {self.error}" if self.error else self.output
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
# return self.copy(update=kwargs)
|
||||
return type(self)(**{**self.dict(), **kwargs})
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
# 为 BaseTool 添加辅助方法
|
||||
def _success_response(data) -> ToolResult:
|
||||
"""创建成功的工具结果"""
|
||||
import json
|
||||
if isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = json.dumps(data, ensure_ascii=False, indent=2)
|
||||
return ToolResult(output=text)
|
||||
|
||||
|
||||
def _fail_response(msg: str) -> ToolResult:
|
||||
"""创建失败的工具结果"""
|
||||
return ToolResult(error=msg)
|
||||
|
||||
|
||||
# 将辅助方法添加到 BaseTool
|
||||
BaseTool.success_response = staticmethod(_success_response)
|
||||
BaseTool.fail_response = staticmethod(_fail_response)
|
||||
Reference in New Issue
Block a user