first commit
This commit is contained in:
64
MeetSpot/app/tool/tool_collection.py
Normal file
64
MeetSpot/app/tool/tool_collection.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from app.exceptions import ToolError
|
||||
from app.tool.base import BaseTool, ToolFailure, ToolResult
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of defined tools."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, *tools: BaseTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def to_params(self) -> List[Dict[str, Any]]:
|
||||
return [tool.to_param() for tool in self.tools]
|
||||
|
||||
async def execute(self, name: str, tool_input: Union[str, dict]) -> ToolResult:
|
||||
"""Execute a tool by name with given input."""
|
||||
tool = self.get_tool(name)
|
||||
if not tool:
|
||||
return ToolResult(error=f"Tool '{name}' not found")
|
||||
|
||||
# 确保 tool_input 是字典类型
|
||||
if isinstance(tool_input, str):
|
||||
try:
|
||||
tool_input = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return ToolResult(error=f"Invalid tool input format: {tool_input}")
|
||||
|
||||
result = await tool(**tool_input)
|
||||
return result
|
||||
|
||||
async def execute_all(self) -> List[ToolResult]:
|
||||
"""Execute all tools in the collection sequentially."""
|
||||
results = []
|
||||
for tool in self.tools:
|
||||
try:
|
||||
result = await tool()
|
||||
results.append(result)
|
||||
except ToolError as e:
|
||||
results.append(ToolFailure(error=e.message))
|
||||
return results
|
||||
|
||||
def get_tool(self, name: str) -> BaseTool:
|
||||
return self.tool_map.get(name)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
|
||||
self.tools += (tool,)
|
||||
self.tool_map[tool.name] = tool
|
||||
return self
|
||||
|
||||
def add_tools(self, *tools: BaseTool):
|
||||
for tool in tools:
|
||||
self.add_tool(tool)
|
||||
return self
|
||||
Reference in New Issue
Block a user