first commit
This commit is contained in:
9
MeetSpot/app/tool/__init__.py
Normal file
9
MeetSpot/app/tool/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.tool.base import BaseTool
|
||||
from app.tool.meetspot_recommender import CafeRecommender
|
||||
from app.tool.tool_collection import ToolCollection
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"CafeRecommender",
|
||||
"ToolCollection",
|
||||
]
|
||||
101
MeetSpot/app/tool/base.py
Normal file
101
MeetSpot/app/tool/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Execute the tool with given parameters."""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> Any:
|
||||
"""Execute the tool with given parameters."""
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
"""Convert tool to function call format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: Any = Field(default=None)
|
||||
error: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
system: Optional[str] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field) for field in self.__fields__)
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"Error: {self.error}" if self.error else self.output
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
# return self.copy(update=kwargs)
|
||||
return type(self)(**{**self.dict(), **kwargs})
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
# 为 BaseTool 添加辅助方法
|
||||
def _success_response(data) -> ToolResult:
|
||||
"""创建成功的工具结果"""
|
||||
import json
|
||||
if isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = json.dumps(data, ensure_ascii=False, indent=2)
|
||||
return ToolResult(output=text)
|
||||
|
||||
|
||||
def _fail_response(msg: str) -> ToolResult:
|
||||
"""创建失败的工具结果"""
|
||||
return ToolResult(error=msg)
|
||||
|
||||
|
||||
# 将辅助方法添加到 BaseTool
|
||||
BaseTool.success_response = staticmethod(_success_response)
|
||||
BaseTool.fail_response = staticmethod(_fail_response)
|
||||
158
MeetSpot/app/tool/file_operators.py
Normal file
158
MeetSpot/app/tool/file_operators.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""File operation interfaces and implementations for local and sandbox environments."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Protocol, Tuple, Union, runtime_checkable
|
||||
|
||||
from app.config import SandboxSettings
|
||||
from app.exceptions import ToolError
|
||||
from app.sandbox.client import SANDBOX_CLIENT
|
||||
|
||||
|
||||
PathLike = Union[str, Path]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class FileOperator(Protocol):
|
||||
"""Interface for file operations in different environments."""
|
||||
|
||||
async def read_file(self, path: PathLike) -> str:
|
||||
"""Read content from a file."""
|
||||
...
|
||||
|
||||
async def write_file(self, path: PathLike, content: str) -> None:
|
||||
"""Write content to a file."""
|
||||
...
|
||||
|
||||
async def is_directory(self, path: PathLike) -> bool:
|
||||
"""Check if path points to a directory."""
|
||||
...
|
||||
|
||||
async def exists(self, path: PathLike) -> bool:
|
||||
"""Check if path exists."""
|
||||
...
|
||||
|
||||
async def run_command(
|
||||
self, cmd: str, timeout: Optional[float] = 120.0
|
||||
) -> Tuple[int, str, str]:
|
||||
"""Run a shell command and return (return_code, stdout, stderr)."""
|
||||
...
|
||||
|
||||
|
||||
class LocalFileOperator(FileOperator):
|
||||
"""File operations implementation for local filesystem."""
|
||||
|
||||
encoding: str = "utf-8"
|
||||
|
||||
async def read_file(self, path: PathLike) -> str:
|
||||
"""Read content from a local file."""
|
||||
try:
|
||||
return Path(path).read_text(encoding=self.encoding)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to read {path}: {str(e)}") from None
|
||||
|
||||
async def write_file(self, path: PathLike, content: str) -> None:
|
||||
"""Write content to a local file."""
|
||||
try:
|
||||
Path(path).write_text(content, encoding=self.encoding)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to write to {path}: {str(e)}") from None
|
||||
|
||||
async def is_directory(self, path: PathLike) -> bool:
|
||||
"""Check if path points to a directory."""
|
||||
return Path(path).is_dir()
|
||||
|
||||
async def exists(self, path: PathLike) -> bool:
|
||||
"""Check if path exists."""
|
||||
return Path(path).exists()
|
||||
|
||||
async def run_command(
|
||||
self, cmd: str, timeout: Optional[float] = 120.0
|
||||
) -> Tuple[int, str, str]:
|
||||
"""Run a shell command locally."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
stdout.decode(),
|
||||
stderr.decode(),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
||||
|
||||
|
||||
class SandboxFileOperator(FileOperator):
|
||||
"""File operations implementation for sandbox environment."""
|
||||
|
||||
def __init__(self):
|
||||
self.sandbox_client = SANDBOX_CLIENT
|
||||
|
||||
async def _ensure_sandbox_initialized(self):
|
||||
"""Ensure sandbox is initialized."""
|
||||
if not self.sandbox_client.sandbox:
|
||||
await self.sandbox_client.create(config=SandboxSettings())
|
||||
|
||||
async def read_file(self, path: PathLike) -> str:
|
||||
"""Read content from a file in sandbox."""
|
||||
await self._ensure_sandbox_initialized()
|
||||
try:
|
||||
return await self.sandbox_client.read_file(str(path))
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None
|
||||
|
||||
async def write_file(self, path: PathLike, content: str) -> None:
|
||||
"""Write content to a file in sandbox."""
|
||||
await self._ensure_sandbox_initialized()
|
||||
try:
|
||||
await self.sandbox_client.write_file(str(path), content)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None
|
||||
|
||||
async def is_directory(self, path: PathLike) -> bool:
|
||||
"""Check if path points to a directory in sandbox."""
|
||||
await self._ensure_sandbox_initialized()
|
||||
result = await self.sandbox_client.run_command(
|
||||
f"test -d {path} && echo 'true' || echo 'false'"
|
||||
)
|
||||
return result.strip() == "true"
|
||||
|
||||
async def exists(self, path: PathLike) -> bool:
|
||||
"""Check if path exists in sandbox."""
|
||||
await self._ensure_sandbox_initialized()
|
||||
result = await self.sandbox_client.run_command(
|
||||
f"test -e {path} && echo 'true' || echo 'false'"
|
||||
)
|
||||
return result.strip() == "true"
|
||||
|
||||
async def run_command(
|
||||
self, cmd: str, timeout: Optional[float] = 120.0
|
||||
) -> Tuple[int, str, str]:
|
||||
"""Run a command in sandbox environment."""
|
||||
await self._ensure_sandbox_initialized()
|
||||
try:
|
||||
stdout = await self.sandbox_client.run_command(
|
||||
cmd, timeout=int(timeout) if timeout else None
|
||||
)
|
||||
return (
|
||||
0, # Always return 0 since we don't have explicit return code from sandbox
|
||||
stdout,
|
||||
"", # No stderr capture in the current sandbox implementation
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds in sandbox"
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
return 1, "", f"Error executing command in sandbox: {str(exc)}"
|
||||
3615
MeetSpot/app/tool/meetspot_recommender.py
Normal file
3615
MeetSpot/app/tool/meetspot_recommender.py
Normal file
File diff suppressed because it is too large
Load Diff
64
MeetSpot/app/tool/tool_collection.py
Normal file
64
MeetSpot/app/tool/tool_collection.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from app.exceptions import ToolError
|
||||
from app.tool.base import BaseTool, ToolFailure, ToolResult
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of defined tools."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, *tools: BaseTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def to_params(self) -> List[Dict[str, Any]]:
|
||||
return [tool.to_param() for tool in self.tools]
|
||||
|
||||
async def execute(self, name: str, tool_input: Union[str, dict]) -> ToolResult:
|
||||
"""Execute a tool by name with given input."""
|
||||
tool = self.get_tool(name)
|
||||
if not tool:
|
||||
return ToolResult(error=f"Tool '{name}' not found")
|
||||
|
||||
# 确保 tool_input 是字典类型
|
||||
if isinstance(tool_input, str):
|
||||
try:
|
||||
tool_input = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return ToolResult(error=f"Invalid tool input format: {tool_input}")
|
||||
|
||||
result = await tool(**tool_input)
|
||||
return result
|
||||
|
||||
async def execute_all(self) -> List[ToolResult]:
|
||||
"""Execute all tools in the collection sequentially."""
|
||||
results = []
|
||||
for tool in self.tools:
|
||||
try:
|
||||
result = await tool()
|
||||
results.append(result)
|
||||
except ToolError as e:
|
||||
results.append(ToolFailure(error=e.message))
|
||||
return results
|
||||
|
||||
def get_tool(self, name: str) -> BaseTool:
|
||||
return self.tool_map.get(name)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
|
||||
self.tools += (tool,)
|
||||
self.tool_map[tool.name] = tool
|
||||
return self
|
||||
|
||||
def add_tools(self, *tools: BaseTool):
|
||||
for tool in tools:
|
||||
self.add_tool(tool)
|
||||
return self
|
||||
101
MeetSpot/app/tool/web_search.py
Normal file
101
MeetSpot/app/tool/web_search.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import config
|
||||
from app.tool.base import BaseTool
|
||||
from app.tool.search import (
|
||||
BaiduSearchEngine,
|
||||
BingSearchEngine,
|
||||
DuckDuckGoSearchEngine,
|
||||
GoogleSearchEngine,
|
||||
WebSearchEngine,
|
||||
)
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
name: str = "web_search"
|
||||
description: str = """Perform a web search and return a list of relevant links.
|
||||
This function attempts to use the primary search engine API to get up-to-date results.
|
||||
If an error occurs, it falls back to an alternative search engine."""
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "(required) The search query to submit to the search engine.",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "(optional) The number of search results to return. Default is 10.",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
_search_engine: dict[str, WebSearchEngine] = {
|
||||
"google": GoogleSearchEngine(),
|
||||
"baidu": BaiduSearchEngine(),
|
||||
"duckduckgo": DuckDuckGoSearchEngine(),
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, query: str, num_results: int = 10) -> List[str]:
|
||||
"""
|
||||
Execute a Web search and return a list of URLs.
|
||||
|
||||
Args:
|
||||
query (str): The search query to submit to the search engine.
|
||||
num_results (int, optional): The number of search results to return. Default is 10.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of URLs matching the search query.
|
||||
"""
|
||||
engine_order = self._get_engine_order()
|
||||
for engine_name in engine_order:
|
||||
engine = self._search_engine[engine_name]
|
||||
try:
|
||||
links = await self._perform_search_with_engine(
|
||||
engine, query, num_results
|
||||
)
|
||||
if links:
|
||||
return links
|
||||
except Exception as e:
|
||||
print(f"Search engine '{engine_name}' failed with error: {e}")
|
||||
return []
|
||||
|
||||
def _get_engine_order(self) -> List[str]:
|
||||
"""
|
||||
Determines the order in which to try search engines.
|
||||
Preferred engine is first (based on configuration), followed by the remaining engines.
|
||||
|
||||
Returns:
|
||||
List[str]: Ordered list of search engine names.
|
||||
"""
|
||||
preferred = "google"
|
||||
if config.search_config and config.search_config.engine:
|
||||
preferred = config.search_config.engine.lower()
|
||||
|
||||
engine_order = []
|
||||
if preferred in self._search_engine:
|
||||
engine_order.append(preferred)
|
||||
for key in self._search_engine:
|
||||
if key not in engine_order:
|
||||
engine_order.append(key)
|
||||
return engine_order
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
)
|
||||
async def _perform_search_with_engine(
|
||||
self,
|
||||
engine: WebSearchEngine,
|
||||
query: str,
|
||||
num_results: int,
|
||||
) -> List[str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, lambda: list(engine.perform_search(query, num_results=num_results))
|
||||
)
|
||||
Reference in New Issue
Block a user