第七章：拦截器与中间件
服务端拦截器
Python 服务端拦截器
import hmac
import logging
import time
from concurrent import futures

import grpc
from grpc import StatusCode


# Logging interceptor
class LoggingInterceptor(grpc.ServerInterceptor):
    """Log the method name and elapsed time of every unary-unary RPC.

    ``intercept_service`` runs before the RPC executes: ``continuation`` only
    resolves the ``grpc.RpcMethodHandler`` for the call. Timing therefore has
    to wrap the handler's ``unary_unary`` behavior, not ``continuation``.
    """

    def intercept_service(self, continuation, handler_call_details):
        method = handler_call_details.method
        handler = continuation(handler_call_details)
        # Unknown methods (None) and streaming handlers pass through untouched.
        if handler is None or handler.unary_unary is None:
            return handler
        behavior = handler.unary_unary

        def logged_behavior(request, context):
            logging.info(f"Request: {method}")
            start_time = time.monotonic()
            try:
                return behavior(request, context)
            finally:
                duration = time.monotonic() - start_time
                logging.info(f"Response: {method} ({duration:.3f}s)")

        # Rebuild the handler so the (de)serializers of the real method are kept.
        return grpc.unary_unary_rpc_method_handler(
            logged_behavior,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )


# Authentication interceptor
class AuthInterceptor(grpc.ServerInterceptor):
    """Reject RPCs whose ``authorization`` metadata is missing or invalid.

    A rejected call is answered by a handler that aborts with
    ``UNAUTHENTICATED`` / ``'Invalid token'``. An accepted call proceeds to
    the next interceptor or the real handler via ``continuation``.
    """

    def __init__(self):
        def abort(ignored_request, context):
            context.abort(StatusCode.UNAUTHENTICATED, 'Invalid token')

        # intercept_service must return an RpcMethodHandler, not a bare
        # callable, so the rejection handler is built once up front.
        self._abortion = grpc.unary_unary_rpc_method_handler(abort)

    def intercept_service(self, continuation, handler_call_details):
        metadata = handler_call_details.invocation_metadata or ()
        # Use the first 'authorization' entry, if any.
        token = next(
            (value for key, value in metadata if key == 'authorization'),
            None,
        )
        if not token or not self._validate_token(token):
            return self._abortion
        return continuation(handler_call_details)

    def _validate_token(self, token):
        """Return True if *token* equals the demo token.

        Uses a constant-time comparison so response timing does not leak how
        much of the token matched. A real service would verify a signed
        token instead of a hard-coded string.
        """
        return hmac.compare_digest(token.encode(), b'valid-token')


# Install the interceptors. They run in list order, so LoggingInterceptor
# wraps whatever AuthInterceptor returns, including rejected calls.
server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=10),
    interceptors=[
        LoggingInterceptor(),
        AuthInterceptor(),
    ],
)
Go 服务端拦截器
// loggingInterceptor is a unary server interceptor. It logs the full method
// name before the handler runs, then logs the elapsed time after it returns,
// together with the error if the handler failed. The handler's response and
// error are passed back unchanged.
func loggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	name := info.FullMethod
	log.Printf("Request: %s", name)

	began := time.Now()
	resp, err := handler(ctx, req)
	elapsed := time.Since(began)

	if err == nil {
		log.Printf("Response: %s (%v)", name, elapsed)
		return resp, nil
	}
	log.Printf("Error: %s (%v) - %v", name, elapsed, err)
	return resp, err
}
// authInterceptor is a unary server interceptor that only lets a call reach
// its handler when the incoming metadata carries an "authorization" value
// accepted by validateToken. Otherwise it fails with codes.Unauthenticated.
// Only the first "authorization" value is checked.
func authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	md, found := metadata.FromIncomingContext(ctx)
	if !found {
		return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
	}

	values := md.Get("authorization")
	switch {
	case len(values) == 0:
		return nil, status.Errorf(codes.Unauthenticated, "missing token")
	case !validateToken(values[0]):
		return nil, status.Errorf(codes.Unauthenticated, "invalid token")
	}
	return handler(ctx, req)
}
// validateToken reports whether token is exactly the hard-coded demo token
// "valid-token".
// NOTE(review): a real service should verify a signed token and compare
// secrets in constant time (crypto/subtle.ConstantTimeCompare).
func validateToken(token string) bool {
	const expected = "valid-token"
	return expected == token
}
// Register the interceptors with grpc-go's built-in chaining option, so no
// extra grpc_middleware dependency is needed. Interceptors run in list order:
// loggingInterceptor is the outermost wrapper, so it also logs calls that
// authInterceptor rejects.
server := grpc.NewServer(
	grpc.ChainUnaryInterceptor(
		loggingInterceptor,
		authInterceptor,
	),
)
客户端拦截器
Python 客户端拦截器
from grpc import UnaryUnaryClientInterceptor
# Logging interceptor (client side)
class LoggingInterceptor(UnaryUnaryClientInterceptor):
    """Log the method name, elapsed time and outcome of each unary-unary call.

    ``continuation`` returns an object that is both a ``grpc.Call`` and a
    ``grpc.Future``. The outcome is logged from a done-callback, so failed
    calls are reported as errors rather than as "Response". Timing is also
    correct for ``.future()`` calls, where ``continuation`` may return before
    the RPC finishes.
    """

    def intercept_unary_unary(self, continuation, client_call_details, request):
        method = client_call_details.method
        logging.info(f"Request: {method}")
        start_time = time.monotonic()
        response = continuation(client_call_details, request)

        def log_outcome(call):
            duration = time.monotonic() - start_time
            code = call.code()
            if code == StatusCode.OK:
                logging.info(f"Response: {method} ({duration:.3f}s)")
            else:
                logging.error(f"Error: {method} ({duration:.3f}s) - {code}")

        response.add_done_callback(log_outcome)
        return response
# Retry interceptor with exponential backoff
class RetryInterceptor(UnaryUnaryClientInterceptor):
    """Re-issue failed unary-unary calls, sleeping 1s, 2s, 4s, ... between attempts.

    Args:
        max_retries: total number of attempts, including the first one.
            Must be at least 1.
        retryable_codes: optional collection of ``StatusCode`` values that
            may be retried. ``None`` (the default) retries every non-OK
            status. Passing e.g. ``{StatusCode.UNAVAILABLE}`` avoids
            re-sending calls that cannot succeed, such as ``INVALID_ARGUMENT``.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
    """

    def __init__(self, max_retries=3, retryable_codes=None):
        if max_retries < 1:
            # With zero attempts there would be no response to return.
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        self.max_retries = max_retries
        self.retryable_codes = (
            None if retryable_codes is None else frozenset(retryable_codes)
        )

    def intercept_unary_unary(self, continuation, client_call_details, request):
        for attempt in range(self.max_retries):
            response = continuation(client_call_details, request)
            code = response.code()
            if code == StatusCode.OK:
                return response
            if self.retryable_codes is not None and code not in self.retryable_codes:
                return response  # Not worth retrying; surface it immediately.
            if attempt + 1 == self.max_retries:
                break  # Out of attempts: return now instead of sleeping first.
            logging.warning(f"Retry {attempt + 1}/{self.max_retries}")
            time.sleep(2 ** attempt)  # Exponential backoff.
        return response
# Wrap a plaintext channel with the client interceptors. intercept_channel
# hands control to them in the order given: logging first, then retry.
base_channel = grpc.insecure_channel('localhost:50051')
channel = grpc.intercept_channel(base_channel, LoggingInterceptor(), RetryInterceptor())
Go 客户端拦截器
// loggingInterceptor is a unary client interceptor. It logs the method name,
// invokes the RPC, then logs the elapsed time together with the error if the
// call failed. The invoker's error is returned unchanged.
func loggingInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	log.Printf("Request: %s", method)

	began := time.Now()
	callErr := invoker(ctx, method, req, reply, cc, opts...)
	elapsed := time.Since(began)

	if callErr == nil {
		log.Printf("Response: %s (%v)", method, elapsed)
		return nil
	}
	log.Printf("Error: %s (%v) - %v", method, elapsed, callErr)
	return callErr
}
// retryInterceptor returns a unary client interceptor that re-invokes a failed
// RPC, waiting 1s, 2s, 4s, ... between attempts. maxRetries is the total
// number of attempts including the first; values below 1 are treated as 1, so
// the RPC is always invoked at least once. The backoff wait stops early when
// ctx is done, and there is no wait after the final attempt. The last RPC
// error is returned.
func retryInterceptor(maxRetries int) grpc.UnaryClientInterceptor {
	if maxRetries < 1 {
		maxRetries = 1
	}
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		var lastErr error
		for attempt := 0; attempt < maxRetries; attempt++ {
			err := invoker(ctx, method, req, reply, cc, opts...)
			if err == nil {
				return nil
			}
			lastErr = err
			if attempt+1 == maxRetries {
				break // out of attempts: don't sleep before giving up
			}
			log.Printf("Retry %d/%d: %v", attempt+1, maxRetries, err)

			// Exponential backoff that respects cancellation and deadlines.
			backoff := time.Duration(1<<uint(attempt)) * time.Second
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return lastErr
			case <-timer.C:
			}
		}
		return lastErr
	}
}
// Create the client connection. The built-in WithChainUnaryInterceptor runs the
// interceptors in list order: loggingInterceptor is outermost, so its timing
// covers all retry attempts. Dial errors are fatal here rather than silently
// ignored.
conn, err := grpc.Dial(
	"localhost:50051",
	grpc.WithTransportCredentials(insecure.NewCredentials()),
	grpc.WithChainUnaryInterceptor(
		loggingInterceptor,
		retryInterceptor(3),
	),
)
if err != nil {
	log.Fatalf("failed to dial: %v", err)
}
中间件库
go-grpc-middleware
import (
"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
"github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/grpc-ecosystem/go-grpc-middleware/ratelimit"
)
// 创建服务器
server := grpc.NewServer(
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
grpc_recovery.UnaryServerInterceptor(), // 恢复 panic
grpc_zap.UnaryServerInterceptor(logger), // 日志
grpc_ratelimit.UnaryServerInterceptor(limiter), // 限流
)),
)
grpc-interceptor
from grpc_interceptor import ServerInterceptor
class LoggingInterceptor(ServerInterceptor):
    """grpc-interceptor based server interceptor that logs every RPC.

    Logs the method name on entry. On success it logs a "Response" line and
    returns the handler's result. If the handler raises, it logs the
    exception message and re-raises it unchanged.
    """

    def intercept(self, method, request, context, method_name):
        logging.info(f"Request: {method_name}")
        try:
            result = method(request, context)
        except Exception as exc:
            logging.error(f"Error: {method_name} - {exc}")
            raise
        else:
            logging.info(f"Response: {method_name}")
            return result
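小结
拦截器与中间件要点：
- 服务端拦截器：日志、认证、限流
- 客户端拦截器：日志、重试
- 中间件库：go-grpc-middleware（Go）、grpc-interceptor（Python）
- 链式拦截器：组合多个拦截器，按列表顺序执行，第一个拦截器位于最外层；grpc-go 已内置 grpc.ChainUnaryInterceptor / grpc.WithChainUnaryInterceptor
下一章我们将学习生产实践。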