Files
WoMenQuNaJu/MeetSpot/app/tool/tool_collection.py
2026-02-04 16:11:55 +08:00

65 lines
1.9 KiB
Python

"""Collection classes for managing multiple tools."""
import json
from typing import Any, Dict, List, Union
from app.exceptions import ToolError
from app.tool.base import BaseTool, ToolFailure, ToolResult
class ToolCollection:
    """An ordered, name-indexed collection of tools.

    Two views of the same tools are kept in sync:

    * ``tools`` -- a tuple in insertion order, used for iteration,
      ``to_params()`` and ``execute_all()``.
    * ``tool_map`` -- a dict from ``tool.name`` to the tool, used by
      ``get_tool()`` and ``execute()``.

    Tool names are unique within a collection: adding a tool whose name is
    already present replaces the existing entry at its original position.
    """

    class Config:
        # Pydantic-style config flag. This class does not subclass a pydantic
        # model, so the flag has no effect here.
        # NOTE(review): presumably kept for parity with model classes
        # elsewhere -- confirm before removing.
        arbitrary_types_allowed = True

    def __init__(self, *tools: BaseTool):
        """Create a collection containing ``tools`` in the given order.

        If several tools share a name, the last one wins and occupies the
        position of the first occurrence.
        """
        self.tools = ()
        self.tool_map: Dict[str, BaseTool] = {}
        self.add_tools(*tools)

    def __iter__(self):
        """Iterate over the tools in insertion order."""
        return iter(self.tools)

    def to_params(self) -> List[Dict[str, Any]]:
        """Return each tool's ``to_param()`` result, in insertion order."""
        return [tool.to_param() for tool in self.tools]

    async def execute(self, name: str, tool_input: Union[str, dict]) -> ToolResult:
        """Execute the tool called ``name``, passing ``tool_input`` as kwargs.

        Args:
            name: Name of the tool to run.
            tool_input: The tool's arguments, either as a mapping or as a JSON
                string that encodes a JSON object. ``None`` is treated as
                "no arguments".

        Returns:
            The tool's own result, or a ``ToolResult`` whose ``error`` is set
            when the tool is unknown or the input cannot be turned into a
            mapping of keyword arguments.

        Raises:
            Whatever the tool itself raises; exceptions from the tool are not
            caught here.
        """
        from collections.abc import Mapping

        tool = self.get_tool(name)
        if tool is None:
            return ToolResult(error=f"Tool '{name}' not found")

        # Normalize tool_input into a mapping of keyword arguments.
        raw_input = tool_input
        if tool_input is None:
            tool_input = {}
        elif isinstance(tool_input, str):
            try:
                tool_input = json.loads(tool_input)
            except json.JSONDecodeError:
                return ToolResult(error=f"Invalid tool input format: {tool_input}")

        # ``**`` unpacking requires a mapping. Valid JSON that is not an
        # object (e.g. "[1, 2]" or "42") would otherwise raise TypeError here
        # instead of producing an error result like malformed JSON does.
        if not isinstance(tool_input, Mapping):
            return ToolResult(error=f"Invalid tool input format: {raw_input}")

        return await tool(**tool_input)

    async def execute_all(self) -> List[ToolResult]:
        """Run every tool with no arguments, sequentially, in insertion order.

        A ``ToolError`` raised by a tool is recorded as a ``ToolFailure`` in
        that tool's slot so the remaining tools still run; any other
        exception propagates to the caller.
        """
        results: List[ToolResult] = []
        for tool in self.tools:
            try:
                results.append(await tool())
            except ToolError as e:
                results.append(ToolFailure(error=e.message))
        return results

    def get_tool(self, name: str) -> Union[BaseTool, None]:
        """Return the tool registered under ``name``, or ``None`` if absent."""
        return self.tool_map.get(name)

    def add_tool(self, tool: BaseTool):
        """Add ``tool`` to the collection and return ``self`` for chaining.

        If a tool with the same name is already present, it is replaced at
        its original position. Otherwise ``tools`` would hold both tools
        (duplicating the name in ``to_params()``) while ``tool_map`` only
        pointed at the newer one.
        """
        existing = self.tool_map.get(tool.name)
        if existing is None:
            self.tools += (tool,)
        else:
            self.tools = tuple(
                tool if current is existing else current for current in self.tools
            )
        self.tool_map[tool.name] = tool
        return self

    def add_tools(self, *tools: BaseTool):
        """Add each of ``tools`` in order via ``add_tool``; return ``self``."""
        for tool in tools:
            self.add_tool(tool)
        return self