跳转至

第四章:自定义工具开发

工具设计原则

1. 单一职责

每个工具应该只做一件事。职责单一的工具名称和描述更明确,模型更容易选对工具,也便于为读、写等操作分别控制权限:

# ❌ Bad design: a single tool that does too much — its description covers
#    create, read, update and delete all at once.
{
    "name": "manage_database",
    "description": "管理数据库,支持增删改查"
}

# ✅ Good design: responsibilities split, one operation per tool.
# (Illustrative bare dict literals: each is evaluated and then discarded.)
{
    "name": "query_database",
    "description": "查询数据库"
}
{
    "name": "insert_record",
    "description": "插入记录"
}
{
    "name": "update_record",
    "description": "更新记录"
}

2. 清晰的输入输出

def search_products(query: str, category: str = None, limit: int = 10) -> dict:
    """
    Search products (illustrative stub with a fixed result shape).

    Args:
        query: search keyword; echoed back in ``message``.
        category: optional category filter (not used by this stub).
        limit: maximum number of results (not used by this stub).

    Returns:
        A dict shaped like
        ``{"success": True, "results": [...], "total": 100, "message": ...}``.
        This placeholder always reports 100 hits and a single ``...``
        (Ellipsis) placeholder in ``results``.
    """
    summary = f"找到 100 个匹配 '{query}' 的结果"
    payload = dict(success=True, results=[...], total=100, message=summary)
    return payload

3. 错误处理

def get_user_info(user_id: str) -> dict:
    """
    Fetch one user's basic profile and wrap it in a tool-result dict.

    Args:
        user_id: identifier passed unchanged to ``db.get_user``.

    Returns:
        ``{"success": True, "user": {"id", "name", "email"}}`` when a user is
        found; ``{"success": False, "error": "USER_NOT_FOUND", "message": ...}``
        when ``db.get_user`` returns a falsy value; and
        ``{"success": False, "error": "DATABASE_ERROR", "message": ...}`` when
        a ``DatabaseError`` is raised anywhere inside the lookup.
    """
    try:
        record = db.get_user(user_id)
        if record:
            profile = {
                "id": record.id,
                "name": record.name,
                "email": record.email,
            }
            return {"success": True, "user": profile}
        return {
            "success": False,
            "error": "USER_NOT_FOUND",
            "message": f"用户 {user_id} 不存在",
        }
    except DatabaseError as exc:
        return {
            "success": False,
            "error": "DATABASE_ERROR",
            "message": str(exc),
        }

常见工具类型

1. 数据查询工具

def query_database(
    table: str,
    columns: list = None,
    where: dict = None,
    order_by: str = None,
    limit: int = 100
) -> dict:
    """
    Query a database table with a simple SELECT.

    Table and column names cannot be sent as bind parameters, so every
    identifier that ends up in the SQL text (``table``, ``columns``, the keys
    of ``where`` and ``order_by``) is first checked against a strict
    identifier pattern. ``where`` values are always passed to ``db.execute``
    as bind parameters. The arguments usually come from a model, and these
    checks keep them from injecting SQL.

    Args:
        table: table name.
        columns: list of column names; ``None`` or empty selects ``*``.
        where: ``{column: value}`` equality conditions, joined with AND.
        order_by: column to sort by, optionally followed by ASC or DESC.
        limit: maximum number of rows, an int in ``[0, 1000]`` (the same
            upper bound the tool schema advertises).

    Returns:
        ``{"success": True, "data": rows, "count": len(rows)}`` on success,
        otherwise ``{"success": False, "error": message}``.
    """
    import re

    max_limit = 1000
    identifier = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
    sort_spec = re.compile(
        r"[A-Za-z_][A-Za-z0-9_]*(?:\s+(?:ASC|DESC))?", re.IGNORECASE
    )

    def invalid_identifier(name):
        return {"success": False, "error": f"非法标识符: {name!r}"}

    # A bare string would otherwise be joined character by character.
    if columns is not None and not isinstance(columns, (list, tuple)):
        return {"success": False, "error": "columns 必须是列名列表"}
    if where is not None and not isinstance(where, dict):
        return {"success": False, "error": "where 必须是键值对对象"}
    for name in [table, *(columns or ()), *(where or {})]:
        if not isinstance(name, str) or not identifier.fullmatch(name):
            return invalid_identifier(name)
    if order_by and not (isinstance(order_by, str) and sort_spec.fullmatch(order_by)):
        return invalid_identifier(order_by)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(limit, bool) or not isinstance(limit, int) or not 0 <= limit <= max_limit:
        return {"success": False, "error": f"limit 必须是 0 到 {max_limit} 之间的整数"}

    # Build the query: identifiers are validated above, values use %s placeholders.
    query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table}"
    if where:
        query += " WHERE " + " AND ".join(f"{key} = %s" for key in where)
    if order_by:
        query += f" ORDER BY {order_by}"
    query += f" LIMIT {limit}"

    try:
        results = db.execute(query, list(where.values()) if where else None)
        return {
            "success": True,
            "data": results,
            "count": len(results)
        }
    except Exception as e:  # tool boundary: report DB failures to the model
        return {"success": False, "error": str(e)}


# Tool definition (JSON schema) advertised to the model for query_database.
# Bound to a name so it can be passed in a `tools=[...]` list. Note that
# "enum"/"maximum" only describe what the model should send; they are hints,
# not enforcement, so the server side must still validate its inputs.
query_database_tool = {
    "type": "function",
    "function": {
        "name": "query_database",
        "description": "查询数据库表数据",
        "parameters": {
            "type": "object",
            "properties": {
                "table": {
                    "type": "string",
                    "description": "表名",
                    # Restrict which tables the model is offered.
                    "enum": ["users", "orders", "products"]
                },
                "columns": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "要查询的列名"
                },
                "where": {
                    "type": "object",
                    "description": "查询条件,键值对形式"
                },
                # The function accepts order_by, so advertise it to the model.
                "order_by": {
                    "type": "string",
                    "description": "排序字段"
                },
                "limit": {
                    "type": "integer",
                    "description": "返回数量限制",
                    "default": 100,
                    "minimum": 0,
                    "maximum": 1000
                }
            },
            "required": ["table"]
        }
    }
}

2. API 调用工具

import json
from typing import Any, Optional

import requests

def call_external_api(
    service: str,
    endpoint: str,
    method: str = "GET",
    params: dict = None,
    data: dict = None
) -> dict:
    """
    Call one of a fixed set of external JSON APIs.

    API keys are read from environment variables (``WEATHER_API_KEY``,
    ``STOCK_API_KEY``) instead of being hard-coded in source, where they
    would end up in version control.

    Args:
        service: service name; must be one of the keys configured below.
        endpoint: path relative to the service's base URL (a leading ``/``
            is tolerated).
        method: HTTP method; GET, POST, PUT, PATCH or DELETE
            (case-insensitive).
        params: URL query parameters.
        data: JSON request body.

    Returns:
        ``{"success": True, "data": parsed_json}`` on success, otherwise
        ``{"success": False, "error": message}``.
    """
    import os

    # Per-service base URL and the environment variable that holds its key.
    api_configs = {
        "weather": {
            "base_url": "https://api.weather.com",
            "api_key_env": "WEATHER_API_KEY"
        },
        "stock": {
            "base_url": "https://api.stock.com",
            "api_key_env": "STOCK_API_KEY"
        }
    }
    allowed_methods = {"GET", "POST", "PUT", "PATCH", "DELETE"}

    if service not in api_configs:
        return {"success": False, "error": f"未知服务: {service}"}

    http_method = str(method).upper()
    if http_method not in allowed_methods:
        return {"success": False, "error": f"不支持的 HTTP 方法: {method}"}

    config = api_configs[service]
    api_key = os.environ.get(config["api_key_env"])
    if not api_key:
        return {"success": False, "error": f"未配置环境变量 {config['api_key_env']}"}

    # Strip leading slashes so "/forecast" does not yield "https://host//forecast".
    url = f"{config['base_url']}/{endpoint.lstrip('/')}"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    try:
        response = requests.request(
            method=http_method,
            url=url,
            headers=headers,
            params=params,
            json=data,
            timeout=30
        )
        response.raise_for_status()

        return {
            "success": True,
            "data": response.json()
        }
    except requests.RequestException as e:
        return {"success": False, "error": str(e)}
    except ValueError as e:
        # Non-JSON body: older requests versions raise a plain ValueError here.
        return {"success": False, "error": f"响应不是有效的 JSON: {e}"}

3. 文件操作工具

import os
from pathlib import Path

def _resolve_inside(file_path: str, base_dir: str):
    """Return ``base_dir / file_path`` fully resolved, or ``None`` if it escapes.

    ``..`` segments are collapsed and symlinks followed before the
    containment check. Absolute paths, ``../`` tricks and symlinks pointing
    outside ``base_dir`` therefore all yield ``None``. Any ValueError raised
    while resolving, such as an embedded NUL byte, is also treated as an
    illegal path.
    """
    try:
        base = Path(base_dir).resolve()
        target = (base / file_path).resolve()
        target.relative_to(base)
    except ValueError:
        return None
    return target


def read_file(file_path: str, base_dir: str = "/allowed/directory") -> dict:
    """
    Read a UTF-8 text file located under ``base_dir``.

    Args:
        file_path: path relative to ``base_dir``.
        base_dir: the only directory tree the tool may read from.

    Returns:
        ``{"success": True, "content", "size", "path"}``, where ``size`` is
        the number of characters (not bytes). On failure, returns
        ``{"success": False, "error": message}``. ``"非法路径"`` is returned
        only for paths that escape ``base_dir``; decode errors are reported
        as themselves.
    """
    try:
        target = _resolve_inside(file_path, base_dir)
        if target is None:
            return {"success": False, "error": "非法路径"}
        if not target.exists():
            return {"success": False, "error": "文件不存在"}
        if not target.is_file():
            return {"success": False, "error": "不是普通文件"}

        with open(target, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:  # tool boundary: permissions, undecodable bytes, ...
        return {"success": False, "error": str(e)}

    return {
        "success": True,
        "content": content,
        "size": len(content),
        "path": str(target)
    }


def write_file(file_path: str, content: str, base_dir: str = "/allowed/directory") -> dict:
    """
    Write ``content`` as UTF-8 to a file under ``base_dir``.

    Missing parent directories are created.

    Args:
        file_path: path relative to ``base_dir``.
        content: text to write; any existing file is overwritten.
        base_dir: the only directory tree the tool may write to.

    Returns:
        ``{"success": True, "message", "size"}``, where ``size`` is the number
        of characters written. On failure, returns
        ``{"success": False, "error": message}``.
    """
    try:
        target = _resolve_inside(file_path, base_dir)
        if target is None:
            return {"success": False, "error": "非法路径"}

        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as f:
            f.write(content)
    except Exception as e:  # tool boundary: permissions, unencodable text, ...
        return {"success": False, "error": str(e)}

    return {
        "success": True,
        "message": f"文件已保存到 {target}",
        "size": len(content)
    }

4. 计算工具

def calculate(expression: str) -> dict:
    """
    Safely evaluate an arithmetic expression without ``eval``.

    The expression is parsed with :mod:`ast`. Only int/float literals, the
    binary operators ``+ - * / // % **`` and unary ``+``/``-`` are evaluated.
    Anything else (names, calls, attributes, strings, booleans, complex
    literals) is rejected. Integer powers whose result would exceed
    ``max_pow_bits`` bits are also refused before they are computed, because
    an input such as ``9 ** 9 ** 9`` would otherwise tie up the tool's CPU
    and memory.

    Args:
        expression: arithmetic expression such as ``"2 + 3 * 4"``.

    Returns:
        ``{"success": True, "expression": expression, "result": value}`` or
        ``{"success": False, "error": "计算错误: ..."}``.
    """
    import ast
    import math
    import operator

    # ~3000 decimal digits: large enough for real use, small enough to stay fast.
    max_pow_bits = 10_000

    def safe_pow(base, exponent):
        # exponent * log2(|base|) is the bit length of the result; check it
        # before computing so huge integer powers are refused cheaply.
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and exponent * math.log2(abs(base)) > max_pow_bits
        ):
            raise ValueError("幂运算结果过大")
        return operator.pow(base, exponent)

    # Allowed operators
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: safe_pow,
    }
    unary_ops = {
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    def eval_node(node):
        # ast.Constant replaces ast.Num (deprecated since 3.8, removed in 3.14).
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is an int subclass; "True + 1" is not arithmetic input.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
        elif isinstance(node, ast.BinOp):
            op = binary_ops.get(type(node.op))
            if op:
                return op(eval_node(node.left), eval_node(node.right))
        elif isinstance(node, ast.UnaryOp):
            op = unary_ops.get(type(node.op))
            if op:
                return op(eval_node(node.operand))

        raise ValueError(f"不支持的表达式: {ast.dump(node)}")

    try:
        tree = ast.parse(expression, mode='eval')
        result = eval_node(tree.body)
        return {
            "success": True,
            "expression": expression,
            "result": result
        }
    except Exception as e:  # tool boundary: syntax errors, ZeroDivisionError, ...
        return {
            "success": False,
            "error": f"计算错误: {str(e)}"
        }

工具装饰器

手写 JSON Schema 既繁琐,又容易与函数签名不一致。可以用装饰器从函数签名自动生成工具定义:

from functools import wraps
import inspect

# Global tool registry: (tool_definition, original_function) pairs, in
# decoration order.
TOOL_REGISTRY = []


def _json_schema_for(annotation) -> dict:
    """Map a Python type annotation to a minimal JSON Schema fragment.

    Handles str/int/float/bool/list/dict. ``list[X]`` adds ``items`` for
    ``X``. ``Optional[X]`` / ``X | None`` map to the schema of ``X``.
    Anything unrecognised, including a missing annotation, falls back to
    ``{"type": "string"}``. Always returns a new dict, so callers may mutate
    it.
    """
    import types
    import typing

    simple_types = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
    }
    try:
        json_type = simple_types.get(annotation)
    except TypeError:  # unhashable annotation object
        json_type = None
    if json_type:
        return {"type": json_type}

    origin = typing.get_origin(annotation)
    args = typing.get_args(annotation)
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    if origin in union_origins:
        non_none = [arg for arg in args if arg is not type(None)]
        if len(non_none) == 1:
            return _json_schema_for(non_none[0])
    elif origin is list:
        schema = {"type": "array"}
        if args:
            schema["items"] = _json_schema_for(args[0])
        return schema
    elif origin is dict:
        return {"type": "object"}
    return {"type": "string"}


def tool(description: str = None, **param_descriptions):
    """
    Decorator that derives an OpenAI-style tool definition from a signature.

    The definition is registered in ``TOOL_REGISTRY`` (together with the
    undecorated function) and attached to the returned wrapper as
    ``wrapper.tool_definition``. Parameters without a default are listed in
    ``required``. ``*args``/``**kwargs`` are skipped because JSON tool
    arguments are always named.

    Args:
        description: tool description; defaults to the function's cleaned
            docstring, or ``""``.
        **param_descriptions: per-parameter descriptions keyed by parameter
            name. (A parameter literally named ``description`` cannot be
            described this way.)

    Usage:
    @tool(description="获取天气信息", city="城市名称")
    def get_weather(city: str) -> dict:
        ...
    """
    import typing

    def decorator(func):
        sig = inspect.signature(func)
        try:
            # Resolves string / postponed annotations into real types.
            hints = typing.get_type_hints(func)
        except Exception:
            hints = {}
        parameters = {"type": "object", "properties": {}, "required": []}

        for name, param in sig.parameters.items():
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue

            prop = _json_schema_for(hints.get(name, param.annotation))
            if name in param_descriptions:
                prop["description"] = param_descriptions[name]

            parameters["properties"][name] = prop

            if param.default is inspect.Parameter.empty:
                parameters["required"].append(name)

        # Register the tool
        tool_def = {
            "type": "function",
            "function": {
                "name": func.__name__,
                # inspect.getdoc strips the docstring's common indentation.
                "description": description or inspect.getdoc(func) or "",
                "parameters": parameters
            }
        }

        TOOL_REGISTRY.append((tool_def, func))

        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.tool_definition = tool_def
        return wrapper

    return decorator


# Usage example: the decorator builds the JSON schema from the signature
# below; the explicit description and per-parameter texts are used verbatim.
@tool(
    description="获取指定城市的天气信息",
    city="城市名称,如:北京、上海",
    unit="温度单位:celsius 或 fahrenheit"
)
def get_weather(city: str, unit: str = "celsius") -> dict:
    """Return a canned reading (22 degrees) for ``city`` — demo data only."""
    reading = dict(city=city, temperature=22, unit=unit)
    return reading


# Collect every registered schema (to send to the model) and a
# name -> function lookup used to dispatch the model's tool calls.
tools = [definition for definition, _fn in TOOL_REGISTRY]
functions_map = {
    definition["function"]["name"]: fn for definition, fn in TOOL_REGISTRY
}

异步工具

import asyncio
import aiohttp

async def async_http_request(url: str, method: str = "GET", timeout: float = 30) -> dict:
    """
    Send an HTTP request asynchronously and return the decoded JSON body.

    Args:
        url: absolute URL to request. NOTE(review): when the URL comes from
            a model, consider restricting allowed hosts to avoid SSRF.
        method: HTTP method.
        timeout: total time limit for the whole request, in seconds. It is
            set explicitly to the same 30 s used by ``call_external_api``
            instead of relying on aiohttp's much longer default.

    Returns:
        The parsed JSON body. aiohttp raises if the response's content type
        is not JSON.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.request(method, url) as response:
            return await response.json()


class AsyncToolExecutor:
    """Register tools and dispatch model tool calls to them, sync or async.

    A ``tool_call`` is read as ``tool_call.function.name`` and
    ``tool_call.function.arguments`` (a JSON-encoded object), i.e. the
    OpenAI-style shape — adapt if your SDK differs.
    """

    def __init__(self):
        self.tools = {}             # tool name -> callable
        self.tool_definitions = []  # tool schemas, at most one per name

    def register(self, definition: dict, func):
        """Register ``func`` under ``definition["function"]["name"]``.

        Registering an existing name replaces its callable and its schema
        in place instead of appending a duplicate schema.
        """
        name = definition["function"]["name"]
        self.tools[name] = func
        for index, existing in enumerate(self.tool_definitions):
            if existing["function"]["name"] == name:
                self.tool_definitions[index] = definition
                break
        else:
            self.tool_definitions.append(definition)

    async def execute(self, tool_call) -> Any:
        """Run one tool call and return the tool's result.

        An unknown tool name or malformed arguments (invalid JSON, or JSON
        that is not an object) produce ``{"success": False, "error": ...}``.
        That is the same shape the tools in this chapter return, so the
        error can go back to the model instead of aborting the whole batch.
        Exceptions raised by the tool itself still propagate.
        """
        name = tool_call.function.name
        func = self.tools.get(name)
        if func is None:
            return {"success": False, "error": f"未知工具: {name}"}

        raw_args = tool_call.function.arguments
        if isinstance(raw_args, dict):
            args = raw_args
        else:
            try:
                # Zero-argument calls may arrive with "" or None as arguments.
                args = json.loads(raw_args) if raw_args else {}
            except json.JSONDecodeError as e:
                return {"success": False, "error": f"参数不是有效的 JSON: {e}"}
        if not isinstance(args, dict):
            return {"success": False, "error": "参数必须是 JSON 对象"}

        # Handles sync functions, async functions, partials of async
        # functions and any other callable that returns an awaitable.
        result = func(**args)
        if inspect.isawaitable(result):
            result = await result
        return result

    async def execute_all(self, tool_calls: list) -> list:
        """Run all tool calls concurrently; results keep the input order.

        Only async tools actually overlap. A synchronous tool runs to
        completion on the event-loop thread before the others can proceed.
        """
        return await asyncio.gather(*(self.execute(tc) for tc in tool_calls))


# Usage example
async def main(tool_calls=None):
    """
    Register a ``fetch_url`` tool and run a batch of tool calls concurrently.

    Args:
        tool_calls: tool-call objects taken from the model's response
            (presumably ``response.choices[0].message.tool_calls`` for an
            OpenAI-style client — adapt to your SDK). Defaults to an empty
            batch.

    Returns:
        One result per tool call, in the same order.
    """
    executor = AsyncToolExecutor()
    executor.register(
        {
            "type": "function",
            "function": {
                "name": "fetch_url",
                "description": "获取 URL 内容",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "url": {"type": "string"}
                    },
                    "required": ["url"]
                }
            }
        },
        async_http_request
    )

    # Run the requests concurrently
    results = await executor.execute_all(tool_calls or [])
    return results

小结

本章学习了:

  • ✅ 工具设计原则
  • ✅ 常见工具类型实现
  • ✅ 工具装饰器简化定义
  • ✅ 异步工具支持

下一章

第五章:工具调用模式 - 学习高级的工具调用模式。