第八章:生产实践¶
系统架构¶
生产级架构¶
┌─────────────┐
│ 负载均衡 │
└──────┬──────┘
│
┌────────────────┼────────────────┐
│ │ │
┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐
│ API 服务 │ │ API 服务 │ │ API 服务 │
└─────┬─────┘ └─────┬─────┘ └─────┬─────┘
│ │ │
└────────────────┼────────────────┘
│
┌────────────────┼────────────────┐
│ │ │
┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐
│ 检索服务 │ │ 向量存储 │ │ 缓存层 │
└───────────┘ └───────────┘ └───────────┘
性能优化¶
批量处理¶
from typing import List
import asyncio
class BatchRetriever:
    """Retrieve results for many queries in fixed-size concurrent batches.

    The wrapped ``retriever`` only needs a synchronous ``invoke(query)``
    method; each call is pushed onto a worker thread so the event loop is
    never blocked.
    """

    def __init__(self, retriever, batch_size: int = 32):
        """
        Args:
            retriever: object exposing a synchronous ``invoke(query)``.
            batch_size: maximum number of queries processed concurrently.
        """
        self.retriever = retriever
        self.batch_size = batch_size

    async def batch_retrieve(self, queries: List[str]) -> List[List]:
        """Retrieve all *queries*, preserving input order in the output."""
        results: List[List] = []
        for i in range(0, len(queries), self.batch_size):
            batch = queries[i:i + self.batch_size]
            # gather preserves batch order, so overall order stays stable
            batch_results = await asyncio.gather(*[
                self._retrieve_one(q) for q in batch
            ])
            results.extend(batch_results)
        return results

    async def _retrieve_one(self, query: str):
        """Run one synchronous retrieval on a worker thread.

        The original implementation called ``self.retriever.invoke``
        directly inside the coroutine, which blocks the event loop and
        serializes the entire "concurrent" batch.  ``asyncio.to_thread``
        restores real concurrency (assumes ``invoke`` is thread-safe --
        confirm for the concrete retriever).
        """
        return await asyncio.to_thread(self.retriever.invoke, query)
# Usage example.  NOTE: must run inside an async context (``await`` is only
# valid in a coroutine); ``hybrid_retriever`` and ``queries`` are assumed to
# be defined in earlier chapters.
retriever = BatchRetriever(hybrid_retriever, batch_size=16)
results = await retriever.batch_retrieve(queries)
缓存策略¶
import hashlib
import json
from functools import lru_cache
from typing import List, Optional

import redis
class CachedRetriever:
    """Retriever wrapper that caches results in Redis.

    NOTE(review): a cache hit returns a list of plain dicts while a cache
    miss returns the retriever's original document objects -- callers must
    handle both shapes.  Confirm whether documents should be reconstructed
    on a hit before relying on this in production.
    """

    def __init__(self, retriever, redis_url: str = "redis://localhost:6379"):
        """
        Args:
            retriever: object exposing ``invoke(query)``.
            redis_url: connection URL for the Redis cache.
        """
        self.retriever = retriever
        self.redis = redis.from_url(redis_url)
        self.cache_ttl = 3600  # entries expire after 1 hour

    def _get_cache_key(self, query: str) -> str:
        """Build a stable cache key; MD5 is fine here (not security-relevant)."""
        return f"retriever:{hashlib.md5(query.encode()).hexdigest()}"

    def retrieve(self, query: str) -> List:
        """Return cached results when available, otherwise retrieve and cache."""
        cache_key = self._get_cache_key(query)
        # Try the cache first
        cached = self.redis.get(cache_key)
        if cached:
            return json.loads(cached)
        # Cache miss: run the real retrieval
        results = self.retriever.invoke(query)
        # Store a JSON-serializable projection of the documents
        self.redis.setex(
            cache_key,
            self.cache_ttl,
            json.dumps([{'content': r.page_content, 'metadata': r.metadata} for r in results])
        )
        return results

    def invalidate_cache(self, query: Optional[str] = None):
        """Drop one cached query, or every retriever cache entry when None."""
        if query:
            self.redis.delete(self._get_cache_key(query))
        else:
            # scan_iter avoids blocking Redis the way a bare KEYS would
            for key in self.redis.scan_iter("retriever:*"):
                self.redis.delete(key)
异步处理¶
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import asyncio
# FastAPI application exposing the async search endpoint
app = FastAPI()

class SearchRequest(BaseModel):
    # Incoming search payload
    query: str
    top_k: int = 5           # number of results to return
    use_rerank: bool = True  # whether to apply the reranking stage

class SearchResult(BaseModel):
    # One retrieved document in the response
    content: str
    score: float
    metadata: dict

class SearchResponse(BaseModel):
    results: List[SearchResult]
    latency_ms: float  # server-side end-to-end latency
@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """Asynchronous search endpoint.

    Runs the (synchronous) retriever and optional reranker on worker
    threads so the event loop stays responsive, then returns the top-k
    results together with the measured server-side latency.
    """
    import time
    # perf_counter is monotonic; time.time() can jump (NTP adjustment)
    # and would occasionally report negative or inflated latencies.
    start = time.perf_counter()
    # Retrieval is synchronous -- off-load it to a worker thread
    results = await asyncio.to_thread(
        hybrid_retriever.invoke,
        request.query
    )
    # Optional reranking stage
    if request.use_rerank:
        results = await asyncio.to_thread(
            reranker.rerank,
            request.query,
            results
        )
    latency = (time.perf_counter() - start) * 1000
    return SearchResponse(
        results=[SearchResult(
            content=r.page_content,
            # score may be absent on some retrievers; default to 0
            score=r.metadata.get('score', 0),
            metadata=r.metadata
        ) for r in results[:request.top_k]],
        latency_ms=latency
    )
可扩展性¶
水平扩展¶
from concurrent.futures import ThreadPoolExecutor
from typing import List
class DistributedRetriever:
    """Fan a query out to several retrieval nodes and merge the results."""

    def __init__(self, retriever_urls: List[str], timeout: float = 10.0):
        """
        Args:
            retriever_urls: base URLs of the retrieval nodes.
            timeout: per-node HTTP timeout in seconds.  New, defaulted
                parameter -- the original code issued requests without a
                timeout, so one hung node could block a request forever.
        """
        self.retriever_urls = retriever_urls
        self.timeout = timeout
        self.executor = ThreadPoolExecutor(max_workers=len(retriever_urls))

    def retrieve_from_node(self, url: str, query: str) -> List:
        """Query one node's /search endpoint and return its result list."""
        import requests
        response = requests.post(
            f"{url}/search",
            json={"query": query, "top_k": 20},
            timeout=self.timeout,  # fail fast instead of hanging
        )
        response.raise_for_status()  # surface HTTP errors explicitly
        return response.json()['results']

    def retrieve(self, query: str) -> List:
        """Query every node in parallel, then deduplicate the merged list."""
        # Fan out to all nodes concurrently
        futures = [
            self.executor.submit(self.retrieve_from_node, url, query)
            for url in self.retriever_urls
        ]
        # Collect in submission order (a failed node raises here)
        all_results = []
        for future in futures:
            all_results.extend(future.result())
        # Deduplicate by content prefix; first occurrence wins, order kept
        seen = set()
        unique_results = []
        for r in all_results:
            key = r['content'][:100]  # first 100 chars as identity proxy
            if key not in seen:
                seen.add(key)
                unique_results.append(r)
        return unique_results
分片策略¶
class ShardedVectorStore:
    """Hash-partition documents across several vector-store shards."""

    def __init__(self, shards: List):
        """
        Args:
            shards: vector stores exposing ``add_documents`` and
                ``similarity_search``.
        """
        self.shards = shards

    def get_shard(self, doc_id: str) -> int:
        """Map a document ID to a shard index.

        Uses MD5 instead of the builtin ``hash``: string hashing is
        randomized per process (PYTHONHASHSEED), so the builtin would
        route the same ID to different shards across restarts and the
        documents would become unfindable.
        """
        import hashlib
        digest = hashlib.md5(doc_id.encode()).hexdigest()
        return int(digest, 16) % len(self.shards)

    def add_documents(self, documents: List):
        """Route each document to its shard keyed by ``metadata['id']``."""
        for doc in documents:
            shard_idx = self.get_shard(doc.metadata['id'])
            self.shards[shard_idx].add_documents([doc])

    def search(self, query: str, top_k: int = 10) -> List:
        """Search every shard, then merge and keep the global top-k."""
        all_results = []
        for shard in self.shards:
            # over-fetch per shard so the global top-k is not missed
            results = shard.similarity_search(query, k=top_k * 2)
            all_results.extend(results)
        # Merge by score, highest first
        all_results.sort(key=lambda x: x.metadata.get('score', 0), reverse=True)
        return all_results[:top_k]
监控与日志¶
监控指标¶
from prometheus_client import Counter, Histogram, Gauge
import time
# Prometheus metric definitions (module-level singletons, registered once)
SEARCH_COUNT = Counter('search_requests_total', 'Total search requests')
SEARCH_LATENCY = Histogram('search_latency_seconds', 'Search latency')
SEARCH_RESULTS = Histogram('search_results_count', 'Number of results returned')
CACHE_HITS = Counter('cache_hits_total', 'Cache hit count')
CACHE_MISSES = Counter('cache_misses_total', 'Cache miss count')
class MonitoredRetriever:
    """Retriever wrapper that records Prometheus metrics (plus optional cache)."""

    def __init__(self, retriever, cache=None):
        """
        Args:
            retriever: object exposing ``invoke(query)``.
            cache: optional cache exposing ``get(query)``.
        """
        self.retriever = retriever
        self.cache = cache

    @SEARCH_LATENCY.time()
    def retrieve(self, query: str, top_k: int = 5) -> list:
        """Retrieve with request, latency, result-count and cache metrics.

        The return annotation uses the builtin ``list`` because this
        section never imports ``typing.List`` -- the original annotation
        raised ``NameError`` when the class body was executed.
        """
        SEARCH_COUNT.inc()
        # Cache lookup, with hit/miss counters
        if self.cache:
            cached = self.cache.get(query)
            if cached:
                CACHE_HITS.inc()
                return cached[:top_k]
            CACHE_MISSES.inc()
        # Cache miss (or no cache): run the real retrieval
        results = self.retriever.invoke(query)
        # Record how many results the backend produced
        SEARCH_RESULTS.observe(len(results))
        return results[:top_k]
日志记录¶
import logging
import json
from datetime import datetime
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('retriever')
class LoggedRetriever:
    """Retriever wrapper that emits one structured JSON log record per call."""

    def __init__(self, retriever):
        """
        Args:
            retriever: object exposing ``invoke(query)``.
        """
        self.retriever = retriever

    def retrieve(self, query: str, top_k: int = 5) -> list:
        """Retrieve and log query, status, result count and latency.

        Fixes in this version: ``time`` was never imported in this section
        (latent ``NameError``), so it is imported locally; ``perf_counter``
        replaces the jump-prone wall clock; the builtin ``list`` annotation
        replaces the un-imported ``typing.List``.
        """
        import time
        start_time = time.perf_counter()  # monotonic; wall clock can jump
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'top_k': top_k
        }
        try:
            results = self.retriever.invoke(query)
            log_entry['status'] = 'success'
            log_entry['result_count'] = len(results)
            log_entry['latency_ms'] = (time.perf_counter() - start_time) * 1000
            # ensure_ascii=False keeps non-ASCII (e.g. Chinese) queries readable
            logger.info(json.dumps(log_entry, ensure_ascii=False))
            return results[:top_k]
        except Exception as e:
            log_entry['status'] = 'error'
            log_entry['error'] = str(e)
            log_entry['latency_ms'] = (time.perf_counter() - start_time) * 1000
            logger.error(json.dumps(log_entry, ensure_ascii=False))
            raise
错误处理¶
重试机制¶
from tenacity import retry, stop_after_attempt, wait_exponential
import requests
class RobustRetriever:
    """Retriever wrapper with retry and an optional keyword-search fallback.

    Fix: the original ``fallback_search`` read ``self.bm25_retriever``,
    which was never assigned anywhere, so every fallback attempt raised
    ``AttributeError``.  The fallback retriever is now an explicit,
    optional constructor argument stored under the same attribute name.
    """

    def __init__(self, retriever, bm25_retriever=None):
        """
        Args:
            retriever: primary retriever exposing ``invoke(query)``.
            bm25_retriever: optional fallback retriever used when the
                primary one keeps failing (new, defaulted parameter).
        """
        self.retriever = retriever
        self.bm25_retriever = bm25_retriever

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10)
    )
    def retrieve(self, query: str) -> list:
        """Retrieve with up to 3 attempts and exponential backoff."""
        try:
            return self.retriever.invoke(query)
        except Exception as e:
            logger.warning(f"检索失败,准备重试: {e}")
            raise

    def retrieve_with_fallback(self, query: str) -> list:
        """Retrieve; after retries are exhausted, fall back to keyword search."""
        try:
            return self.retrieve(query)
        except Exception as e:
            logger.error(f"检索失败,使用降级策略: {e}")
            # Degrade to simple keyword search
            return self.fallback_search(query)

    def fallback_search(self, query: str) -> list:
        """Degraded search path (e.g. BM25); fail loudly if not configured."""
        if self.bm25_retriever is None:
            # Surface the misconfiguration instead of masking it with an
            # incidental AttributeError.
            raise RuntimeError("未配置降级检索器 (bm25_retriever)")
        return self.bm25_retriever.invoke(query)
熔断器¶
from datetime import datetime, timedelta
from typing import Optional
class CircuitBreaker:
    """Minimal circuit breaker with closed / open / half-open states."""

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60
    ):
        """
        Args:
            failure_threshold: consecutive failures that trip the breaker.
            recovery_timeout: seconds to wait before probing again.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.last_failure: Optional[datetime] = None
        self.state = 'closed'  # closed, open, half-open

    def record_failure(self):
        """Count a failure; trip to 'open' once the threshold is reached."""
        self.failures += 1
        self.last_failure = datetime.now()
        if self.failures >= self.failure_threshold:
            self.state = 'open'

    def record_success(self):
        """Reset the failure counter and close the breaker."""
        self.failures = 0
        self.state = 'closed'

    def can_execute(self) -> bool:
        """Return True when a call may proceed under the current state."""
        if self.state == 'closed':
            return True
        if self.state != 'open':
            # half-open: let a probe attempt through
            return True
        # open: allow a probe only once the recovery window has elapsed
        elapsed = datetime.now() - self.last_failure
        if elapsed > timedelta(seconds=self.recovery_timeout):
            self.state = 'half-open'
            return True
        return False
class CircuitBreakerRetriever:
    """Retriever wrapper that stops calling a repeatedly failing backend."""

    def __init__(self, retriever):
        self.retriever = retriever
        self.circuit_breaker = CircuitBreaker()

    def retrieve(self, query: str) -> List:
        """Retrieve through the circuit breaker, recording the outcome."""
        if not self.circuit_breaker.can_execute():
            raise Exception("熔断器打开,服务不可用")
        try:
            documents = self.retriever.invoke(query)
        except Exception:
            self.circuit_breaker.record_failure()
            raise
        self.circuit_breaker.record_success()
        return documents
部署配置¶
Docker 部署¶
# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so this layer is cached across code-only changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Expose the API port
EXPOSE 8000

# Start the FastAPI app with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
Docker Compose¶
# docker-compose.yml
# NOTE(review): the top-level `version` key is obsolete in Compose V2, and
# `deploy.replicas` only takes effect under Docker Swarm -- confirm the
# target runtime before relying on it.
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}  # injected from the host environment
      - REDIS_URL=redis://redis:6379      # service-name DNS on the compose network
    depends_on:
      - redis
      - milvus
    deploy:
      replicas: 3  # horizontal scaling of the API service
      resources:
        limits:
          cpus: '2'
          memory: 4G
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data  # persist cache data across restarts
  milvus:
    image: milvusdb/milvus:latest
    ports:
      - "19530:19530"
    volumes:
      - milvus_data:/var/lib/milvus  # persist vector data
volumes:
  redis_data:
  milvus_data:
Kubernetes 部署¶
# kubernetes/deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: retriever-api
spec:
  replicas: 3  # three API pods behind the Service
  selector:
    matchLabels:
      app: retriever-api
  template:
    metadata:
      labels:
        app: retriever-api
    spec:
      containers:
      - name: api
        # NOTE(review): pin a concrete tag; :latest defeats rollbacks
        image: retriever-api:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "2Gi"
            cpu: "1"
          limits:
            memory: "4Gi"
            cpu: "2"
        # Restart the container when /health stops answering
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        # Only route traffic to the pod once /ready answers
        readinessProbe:
          httpGet:
            path: /ready
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: retriever-api
spec:
  selector:
    app: retriever-api
  ports:
  - port: 80          # external Service port
    targetPort: 8000  # container port
  type: LoadBalancer
小结¶
本章学习了:
- ✅ 生产级系统架构
- ✅ 性能优化策略
- ✅ 可扩展性设计
- ✅ 监控与日志
- ✅ 错误处理机制
- ✅ 部署配置
总结¶
恭喜完成 RAG 进阶教程!你已掌握:
- 混合检索:结合 BM25 和向量检索的优势
- 重排序优化:使用 Cross-Encoder 和 LLM 提升检索质量
- 多模态检索:支持文本、图像、音频、视频检索
- 检索评估:使用多种指标评估检索效果
- 生产实践:构建高性能、可扩展的检索系统
继续探索更多 RAG 技术,构建更强大的智能检索系统!