跳转至

第八章:生产实践

系统架构

生产级架构

                    ┌─────────────┐
                    │   负载均衡   │
                    └──────┬──────┘
          ┌────────────────┼────────────────┐
          │                │                │
    ┌─────▼─────┐    ┌─────▼─────┐    ┌─────▼─────┐
    │  API 服务  │    │  API 服务  │    │  API 服务  │
    └─────┬─────┘    └─────┬─────┘    └─────┬─────┘
          │                │                │
          └────────────────┼────────────────┘
          ┌────────────────┼────────────────┐
          │                │                │
    ┌─────▼─────┐    ┌─────▼─────┐    ┌─────▼─────┐
    │  检索服务  │    │  向量存储  │    │  缓存层   │
    └───────────┘    └───────────┘    └───────────┘

性能优化

批量处理

from typing import List
import asyncio

class BatchRetriever:
    """Retrieve many queries concurrently, in bounded-size batches.

    Fix: the original `_retrieve_one` was `async` but called the
    synchronous `retriever.invoke` directly, so `asyncio.gather` executed
    the calls one after another on the event loop — no concurrency at all.
    Each call is now pushed to a worker thread via `asyncio.to_thread`,
    so a batch of `batch_size` queries genuinely runs in parallel.
    """

    def __init__(self, retriever, batch_size: int = 32):
        self.retriever = retriever
        self.batch_size = batch_size  # max concurrent retrievals per batch

    async def batch_retrieve(self, queries: List[str]) -> List[List]:
        """Retrieve all queries; results are in the same order as the input."""
        results = []

        for i in range(0, len(queries), self.batch_size):
            batch = queries[i:i + self.batch_size]
            # gather preserves input order, so results line up with queries
            batch_results = await asyncio.gather(*[
                self._retrieve_one(q) for q in batch
            ])
            results.extend(batch_results)

        return results

    async def _retrieve_one(self, query: str):
        """Run one (blocking) retrieval in a worker thread."""
        return await asyncio.to_thread(self.retriever.invoke, query)

# Usage example. `hybrid_retriever` and `queries` come from earlier chapters;
# the bare `await` only works inside an async function (or a REPL that
# supports top-level await).
retriever = BatchRetriever(hybrid_retriever, batch_size=16)
results = await retriever.batch_retrieve(queries)

缓存策略

import hashlib
import json
from functools import lru_cache
from typing import List

import redis

class CachedRetriever:
    """Read-through Redis cache in front of a retriever.

    NOTE(review): a cache hit returns plain dicts (the serialized
    projection) while a cache miss returns the retriever's Document
    objects — callers must handle both shapes, or the hit path should
    rebuild Documents. Confirm which is intended.
    """

    def __init__(self, retriever, redis_url: str = "redis://localhost:6379"):
        self.retriever = retriever
        self.redis = redis.from_url(redis_url)
        self.cache_ttl = 3600  # entry lifetime in seconds (1 hour)

    def _get_cache_key(self, query: str) -> str:
        """Deterministic cache key for a query (md5 used as a hash, not for security)."""
        return f"retriever:{hashlib.md5(query.encode()).hexdigest()}"

    def retrieve(self, query: str) -> List:
        """Return cached results if present, otherwise retrieve and cache."""
        cache_key = self._get_cache_key(query)

        # Try the cache first; redis returns bytes, which json.loads accepts.
        cached = self.redis.get(cache_key)
        if cached:
            return json.loads(cached)

        # Cache miss: run the real retrieval.
        results = self.retriever.invoke(query)

        # Store a JSON-serializable projection of the documents.
        self.redis.setex(
            cache_key,
            self.cache_ttl,
            json.dumps([{'content': r.page_content, 'metadata': r.metadata} for r in results])
        )

        return results

    def invalidate_cache(self, query=None):
        """Drop one query's cache entry, or all retriever entries when query is None.

        Fix: the original tested `if query:` (truthiness), so an empty-string
        query fell through and flushed the entire cache. Also deletes all
        keys in a single DELETE instead of one round trip per key.
        """
        if query is not None:
            self.redis.delete(self._get_cache_key(query))
            return
        keys = list(self.redis.scan_iter("retriever:*"))
        if keys:
            self.redis.delete(*keys)

异步处理

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import asyncio

# FastAPI application instance exposing the retrieval HTTP API.
app = FastAPI()

class SearchRequest(BaseModel):
    """Payload for POST /search."""
    query: str              # user query text
    top_k: int = 5          # number of results to return
    use_rerank: bool = True # whether to apply the reranking stage

class SearchResult(BaseModel):
    """One retrieved document in the response."""
    content: str    # document text
    score: float    # relevance score (0 when the retriever supplies none)
    metadata: dict  # passthrough document metadata

class SearchResponse(BaseModel):
    """Response body for POST /search."""
    results: List[SearchResult]  # top_k results, best first
    latency_ms: float            # server-side end-to-end latency

@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """Async search endpoint: retrieve, optionally rerank, return top_k.

    Blocking retriever/reranker calls are pushed to worker threads via
    asyncio.to_thread so the event loop stays responsive.
    """
    import time
    # perf_counter is monotonic; time.time() can jump (NTP, DST) and
    # produce negative or wildly wrong latency measurements.
    start = time.perf_counter()

    # Run the blocking retrieval off the event loop.
    results = await asyncio.to_thread(
        hybrid_retriever.invoke,
        request.query
    )

    # Optional reranking stage.
    if request.use_rerank:
        results = await asyncio.to_thread(
            reranker.rerank,
            request.query,
            results
        )

    latency = (time.perf_counter() - start) * 1000

    return SearchResponse(
        results=[
            SearchResult(
                content=r.page_content,
                # NOTE(review): assumes upstream stages put a 'score' into
                # metadata; defaults to 0 otherwise — confirm with the
                # retriever/reranker implementations.
                score=r.metadata.get('score', 0),
                metadata=r.metadata,
            )
            for r in results[:request.top_k]
        ],
        latency_ms=latency,
    )

可扩展性

水平扩展

from concurrent.futures import ThreadPoolExecutor
from typing import List

class DistributedRetriever:
    """Fan a query out to multiple retriever nodes and merge the results."""

    def __init__(self, retriever_urls: List[str], timeout: float = 10.0):
        self.retriever_urls = retriever_urls
        # Per-node HTTP timeout in seconds. Fix: `requests` has NO default
        # timeout, so without this a single hung node would block the whole
        # query (and its executor thread) forever.
        self.timeout = timeout
        self.executor = ThreadPoolExecutor(max_workers=len(retriever_urls))

    def retrieve_from_node(self, url: str, query: str) -> List:
        """Query one node's /search endpoint and return its result list."""
        import requests
        response = requests.post(
            f"{url}/search",
            json={"query": query, "top_k": 20},
            timeout=self.timeout,
        )
        # Surface HTTP errors explicitly instead of a confusing KeyError below.
        response.raise_for_status()
        return response.json()['results']

    def retrieve(self, query: str) -> List:
        """Query all nodes in parallel, then merge and deduplicate."""
        futures = [
            self.executor.submit(self.retrieve_from_node, url, query)
            for url in self.retriever_urls
        ]

        all_results = []
        for future in futures:
            all_results.extend(future.result())

        return self._dedupe(all_results)

    @staticmethod
    def _dedupe(results: List) -> List:
        """Deduplicate by content prefix, preserving first-seen order."""
        seen = set()
        unique_results = []
        for r in results:
            key = r['content'][:100]  # first 100 chars as identity heuristic
            if key not in seen:
                seen.add(key)
                unique_results.append(r)
        return unique_results

分片策略

class ShardedVectorStore:
    """Vector store sharded across multiple backing stores by document id."""

    def __init__(self, shards: List):
        self.shards = shards

    def get_shard(self, doc_id: str) -> int:
        """Return a stable shard index for doc_id.

        Fix: the original used the builtin hash(), whose output for str is
        randomized per process (PYTHONHASHSEED), so the same document would
        map to a DIFFERENT shard after every restart — writes and reads
        would disagree on placement. An md5-based hash is deterministic
        across processes (used as a hash only, not for security).
        """
        import hashlib
        digest = hashlib.md5(doc_id.encode("utf-8")).hexdigest()
        return int(digest, 16) % len(self.shards)

    def add_documents(self, documents: List):
        """Route each document to its shard by its metadata id."""
        for doc in documents:
            shard_idx = self.get_shard(doc.metadata['id'])
            self.shards[shard_idx].add_documents([doc])

    def search(self, query: str, top_k: int = 10) -> List:
        """Search every shard, then merge and keep the global top_k."""
        all_results = []

        # Over-fetch per shard (2x top_k) so the global top_k is less
        # likely to miss items concentrated in one shard.
        for shard in self.shards:
            all_results.extend(shard.similarity_search(query, k=top_k * 2))

        # NOTE(review): assumes each result carries a 'score' in metadata;
        # items without one sort as 0 — confirm the shard stores set it.
        all_results.sort(key=lambda x: x.metadata.get('score', 0), reverse=True)
        return all_results[:top_k]

监控与日志

监控指标

from prometheus_client import Counter, Histogram, Gauge
import time

# Prometheus metrics (module-level singletons shared by all wrappers).
# NOTE(review): Gauge is imported but unused in this snippet.
SEARCH_COUNT = Counter('search_requests_total', 'Total search requests')
SEARCH_LATENCY = Histogram('search_latency_seconds', 'Search latency')
SEARCH_RESULTS = Histogram('search_results_count', 'Number of results returned')
CACHE_HITS = Counter('cache_hits_total', 'Cache hit count')
CACHE_MISSES = Counter('cache_misses_total', 'Cache miss count')

class MonitoredRetriever:
    """Retriever wrapper that records Prometheus metrics for every call."""

    def __init__(self, retriever, cache=None):
        self.retriever = retriever
        self.cache = cache

    @SEARCH_LATENCY.time()
    def retrieve(self, query: str, top_k: int = 5) -> List:
        """Retrieve with request-count, cache-hit/miss and result-count metrics."""
        SEARCH_COUNT.inc()

        # Serve from cache when one is configured (truthy value == hit).
        if self.cache:
            if (hit := self.cache.get(query)):
                CACHE_HITS.inc()
                return hit[:top_k]
            CACHE_MISSES.inc()

        docs = self.retriever.invoke(query)
        SEARCH_RESULTS.observe(len(docs))
        return docs[:top_k]

日志记录

import logging
import json
import time
from datetime import datetime

# Configure structured logging for the retriever service.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('retriever')

class LoggedRetriever:
    """Retriever wrapper that emits one structured JSON log line per call.

    Fix: this snippet used `time.time()` without importing `time`
    (NameError at call time); now imported, and latency is measured with
    the monotonic `perf_counter` so wall-clock jumps can't corrupt it.
    """

    def __init__(self, retriever):
        self.retriever = retriever

    def retrieve(self, query: str, top_k: int = 5) -> List:
        """Retrieve, logging query, outcome, result count and latency."""
        start_time = time.perf_counter()

        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'top_k': top_k
        }

        try:
            results = self.retriever.invoke(query)

            log_entry['status'] = 'success'
            log_entry['result_count'] = len(results)
            log_entry['latency_ms'] = (time.perf_counter() - start_time) * 1000

            # ensure_ascii=False keeps non-ASCII (e.g. Chinese) queries
            # human-readable in the log instead of \uXXXX escapes.
            logger.info(json.dumps(log_entry, ensure_ascii=False))

            return results[:top_k]

        except Exception as e:
            log_entry['status'] = 'error'
            log_entry['error'] = str(e)
            log_entry['latency_ms'] = (time.perf_counter() - start_time) * 1000

            logger.error(json.dumps(log_entry, ensure_ascii=False))
            raise

错误处理

重试机制

from tenacity import retry, stop_after_attempt, wait_exponential
import requests

class RobustRetriever:
    """Retriever with retry and a keyword-search fallback.

    Fix: `fallback_search` read `self.bm25_retriever`, which was never
    assigned anywhere — a guaranteed AttributeError exactly when the
    fallback path was needed. It is now an optional constructor argument
    (backward-compatible default of None).
    """

    def __init__(self, retriever, bm25_retriever=None):
        self.retriever = retriever
        self.bm25_retriever = bm25_retriever  # optional degraded-mode retriever

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10)
    )
    def retrieve(self, query: str) -> List:
        """Retrieve with up to 3 attempts and exponential backoff."""
        try:
            return self.retriever.invoke(query)
        except Exception as e:
            logger.warning(f"检索失败,准备重试: {e}")
            raise

    def retrieve_with_fallback(self, query: str) -> List:
        """Retrieve; if all retries fail, degrade to simple keyword search."""
        try:
            return self.retrieve(query)
        except Exception as e:
            logger.error(f"检索失败,使用降级策略: {e}")
            return self.fallback_search(query)

    def fallback_search(self, query: str) -> List:
        """Degraded search via BM25; best-effort empty result when unconfigured."""
        if self.bm25_retriever is None:
            return []
        return self.bm25_retriever.invoke(query)

熔断器

from datetime import datetime, timedelta
from typing import Optional

class CircuitBreaker:
    """Failure-counting circuit breaker with closed / open / half-open states."""

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60
    ):
        self.failure_threshold = failure_threshold  # failures before tripping
        self.recovery_timeout = recovery_timeout    # seconds before a retry is allowed
        self.failures = 0
        self.last_failure: Optional[datetime] = None
        self.state = 'closed'  # one of: 'closed', 'open', 'half-open'

    def record_failure(self):
        """Count a failure; trip the breaker once the threshold is reached."""
        self.failures += 1
        self.last_failure = datetime.now()
        if self.failures < self.failure_threshold:
            return
        self.state = 'open'

    def record_success(self):
        """Reset the breaker to its healthy (closed) state."""
        self.failures = 0
        self.state = 'closed'

    def can_execute(self) -> bool:
        """Whether a call may proceed under the current breaker state."""
        if self.state != 'open':
            # 'closed' always allows; 'half-open' allows a trial call.
            return True

        # Open: allow one trial once the recovery window has elapsed.
        elapsed = datetime.now() - self.last_failure
        if elapsed > timedelta(seconds=self.recovery_timeout):
            self.state = 'half-open'
            return True
        return False

class CircuitOpenError(Exception):
    """Raised when the circuit breaker rejects a call (service degraded)."""


class CircuitBreakerRetriever:
    """Retriever wrapper guarded by a circuit breaker."""

    def __init__(self, retriever):
        self.retriever = retriever
        self.circuit_breaker = CircuitBreaker()

    def retrieve(self, query: str) -> List:
        """Retrieve unless the breaker is open; record the outcome either way.

        Fix: the original raised a bare `Exception`, making "breaker open"
        indistinguishable from genuine retrieval errors. CircuitOpenError
        is still an Exception subclass, so existing `except Exception`
        callers keep working.
        """
        if not self.circuit_breaker.can_execute():
            raise CircuitOpenError("熔断器打开,服务不可用")

        try:
            results = self.retriever.invoke(query)
        except Exception:
            self.circuit_breaker.record_failure()
            raise
        self.circuit_breaker.record_success()
        return results

部署配置

Docker 部署

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Copy requirements first so the install layer stays cached
# until requirements.txt actually changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code.
COPY . .

# Port served by uvicorn below.
EXPOSE 8000

# Start the FastAPI app.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

Docker Compose

# docker-compose.yml
version: '3.8'

services:
  # Retrieval API. NOTE(review): the 'deploy' section (replicas/resources)
  # is honored by Docker Swarm, not by plain `docker compose up`.
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - REDIS_URL=redis://redis:6379
    depends_on:
      - redis
      - milvus
    deploy:
      replicas: 3
      resources:
        limits:
          cpus: '2'
          memory: 4G

  # Cache layer used by CachedRetriever.
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data

  # Vector store backend.
  milvus:
    image: milvusdb/milvus:latest
    ports:
      - "19530:19530"
    volumes:
      - milvus_data:/var/lib/milvus

# Named volumes so cache and vector data survive container restarts.
volumes:
  redis_data:
  milvus_data:

Kubernetes 部署

# kubernetes/deployment.yaml
# Deployment: 3 replicas of the retriever API with resource requests/limits
# and health probes; Service: external LoadBalancer in front of them.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: retriever-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: retriever-api
  template:
    metadata:
      labels:
        app: retriever-api
    spec:
      containers:
      - name: api
        image: retriever-api:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "2Gi"
            cpu: "1"
          limits:
            memory: "4Gi"
            cpu: "2"
        # Restart the container if /health stops responding.
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        # Only route traffic to a pod once /ready succeeds.
        readinessProbe:
          httpGet:
            path: /ready
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
---
# Service: exposes port 80 externally, forwarding to container port 8000.
apiVersion: v1
kind: Service
metadata:
  name: retriever-api
spec:
  selector:
    app: retriever-api
  ports:
  - port: 80
    targetPort: 8000
  type: LoadBalancer

小结

本章学习了:

  • ✅ 生产级系统架构
  • ✅ 性能优化策略
  • ✅ 可扩展性设计
  • ✅ 监控与日志
  • ✅ 错误处理机制
  • ✅ 部署配置

总结

恭喜完成 RAG 进阶教程!你已掌握:

  1. 混合检索:结合 BM25 和向量检索的优势
  2. 重排序优化:使用 Cross-Encoder 和 LLM 提升检索质量
  3. 多模态检索:支持文本、图像、音频、视频检索
  4. 检索评估:使用多种指标评估检索效果
  5. 生产实践:构建高性能、可扩展的检索系统

继续探索更多 RAG 技术,构建更强大的智能检索系统!