
第七章:拦截器与中间件

服务端拦截器

Python 拦截器

import hmac
import logging
import time

import grpc
from grpc import StatusCode


# Logging interceptor
class LoggingInterceptor(grpc.ServerInterceptor):
    """Server interceptor that logs every RPC and times unary-unary calls.

    ``continuation(handler_call_details)`` only resolves the method handler;
    the RPC body runs later, when gRPC invokes that handler. Timing the
    ``continuation`` call would therefore not measure the RPC, so the
    handler's unary-unary behavior is wrapped and timed instead.
    """

    def intercept_service(self, continuation, handler_call_details):
        method = handler_call_details.method
        logging.info("Request: %s", method)

        handler = continuation(handler_call_details)
        # None means the method is unknown; streaming handlers have no
        # unary_unary behavior to wrap and are passed through untimed.
        if handler is None or handler.unary_unary is None:
            return handler

        behavior = handler.unary_unary

        def timed_behavior(request, context):
            # monotonic() is not affected by wall-clock adjustments mid-call.
            start_time = time.monotonic()
            try:
                response = behavior(request, context)
            except Exception as exc:
                duration = time.monotonic() - start_time
                logging.error("Error: %s (%.3fs) - %s", method, duration, exc)
                raise
            duration = time.monotonic() - start_time
            logging.info("Response: %s (%.3fs)", method, duration)
            return response

        return grpc.unary_unary_rpc_method_handler(
            timed_behavior,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )


# Authentication interceptor
class AuthInterceptor(grpc.ServerInterceptor):
    """Server interceptor that rejects calls without a valid ``authorization`` token.

    Calls whose metadata has no ``authorization`` entry, or whose token fails
    ``_validate_token``, end with ``UNAUTHENTICATED`` / ``'Invalid token'``
    without reaching the servicer.
    """

    def intercept_service(self, continuation, handler_call_details):
        metadata = handler_call_details.invocation_metadata or ()

        # First value stored under 'authorization', or None when absent.
        token = next(
            (value for key, value in metadata if key == 'authorization'),
            None,
        )

        if not token or not self._validate_token(token):
            # intercept_service must return an RpcMethodHandler, not a bare
            # function; this handler aborts the call instead of replying.
            return grpc.unary_unary_rpc_method_handler(self._reject)

        return continuation(handler_call_details)

    @staticmethod
    def _reject(request, context):
        # abort() raises, ending the RPC with this status and details.
        context.abort(StatusCode.UNAUTHENTICATED, 'Invalid token')

    def _validate_token(self, token):
        # Placeholder check against a fixed demo token. compare_digest does
        # not stop at the first mismatching byte, so response timing does not
        # reveal how much of a guessed token was correct.
        if not isinstance(token, str):
            return False
        return hmac.compare_digest(token.encode('utf-8'), b'valid-token')

# Register the interceptors on the server. gRPC gives control to server
# interceptors in list order, so LoggingInterceptor runs before
# AuthInterceptor and also sees requests that end up rejected.
from concurrent import futures
import grpc

server = grpc.server(
    thread_pool=futures.ThreadPoolExecutor(max_workers=10),
    interceptors=[LoggingInterceptor(), AuthInterceptor()],
)

Go 拦截器

// loggingInterceptor is a unary server interceptor. It logs the full method
// name of each incoming RPC, runs the handler, and then logs the elapsed
// time: as an "Error" line including err when the handler failed, otherwise
// as a "Response" line. The handler's result is returned unchanged.
func loggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    fullMethod := info.FullMethod
    log.Printf("Request: %s", fullMethod)

    began := time.Now()
    reply, callErr := handler(ctx, req)
    elapsed := time.Since(began)

    if callErr == nil {
        log.Printf("Response: %s (%v)", fullMethod, elapsed)
        return reply, nil
    }

    log.Printf("Error: %s (%v) - %v", fullMethod, elapsed, callErr)
    return reply, callErr
}

// authInterceptor is a unary server interceptor that fails the RPC with
// codes.Unauthenticated unless the incoming metadata carries an
// "authorization" value accepted by validateToken. Accepted calls are passed
// to handler unchanged.
func authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    incoming, present := metadata.FromIncomingContext(ctx)
    if !present {
        return nil, status.Error(codes.Unauthenticated, "missing metadata")
    }

    // Only the first "authorization" value is checked.
    values := incoming.Get("authorization")
    switch {
    case len(values) == 0:
        return nil, status.Error(codes.Unauthenticated, "missing token")
    case !validateToken(values[0]):
        return nil, status.Error(codes.Unauthenticated, "invalid token")
    }

    return handler(ctx, req)
}

// validateToken reports whether token equals the fixed demo credential.
// NOTE(review): a real service should verify a signed token, and compare
// secrets with crypto/subtle.ConstantTimeCompare rather than ==.
func validateToken(token string) bool {
    const expected = "valid-token"
    return expected == token
}

// Register the interceptors. grpc.ChainUnaryInterceptor is built into
// grpc-go, so no middleware package is needed. The first interceptor listed
// is the outermost, so loggingInterceptor also logs the calls that
// authInterceptor rejects.
server := grpc.NewServer(
    grpc.ChainUnaryInterceptor(
        loggingInterceptor,
        authInterceptor,
    ),
)

客户端拦截器

Python 客户端拦截器

from grpc import UnaryUnaryClientInterceptor

# Logging interceptor (client side)
class LoggingInterceptor(UnaryUnaryClientInterceptor):
    """Client interceptor that logs each unary-unary call and its latency.

    The outcome is logged from a done-callback rather than inline. This keeps
    the timing correct for ``.future()`` calls, where ``continuation`` returns
    before the RPC has finished.

    A failed RPC is returned by ``continuation`` rather than raised, so
    failures are detected via ``exception()``. They are logged as a separate
    "Error" line, as the Go client interceptor does.
    """

    def intercept_unary_unary(self, continuation, client_call_details, request):
        method = client_call_details.method
        logging.info("Request: %s", method)

        # monotonic() is not affected by wall-clock adjustments mid-call.
        start_time = time.monotonic()
        response = continuation(client_call_details, request)

        def log_outcome(call):
            duration = time.monotonic() - start_time
            if call.exception() is None:
                logging.info("Response: %s (%.3fs)", method, duration)
            else:
                logging.error(
                    "Error: %s (%.3fs) - %s", method, duration, call.code()
                )

        # Invoked immediately when the call has already completed.
        response.add_done_callback(log_outcome)
        return response

# Retry interceptor
class RetryInterceptor(UnaryUnaryClientInterceptor):
    """Client interceptor that retries unary-unary calls on transient failures.

    Args:
        max_retries: Total number of attempts, the first call included.
            Values below 1 are treated as 1, so the call always runs once.
        retryable_codes: Status codes worth retrying. Defaults to
            ``UNAVAILABLE`` only: codes such as ``INVALID_ARGUMENT`` or
            ``UNAUTHENTICATED`` would fail the same way on every attempt.

    Waiting for each attempt's status makes the call blocking, even when it
    was started with ``.future()``.
    """

    def __init__(
        self, max_retries=3, retryable_codes=(StatusCode.UNAVAILABLE,)
    ):
        self.max_retries = max_retries
        self.retryable_codes = frozenset(retryable_codes)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            response = continuation(client_call_details, request)
            code = response.code()

            if code == StatusCode.OK or code not in self.retryable_codes:
                return response
            if attempt + 1 == attempts:
                break  # out of attempts: give up without a pointless sleep

            logging.warning("Retry %d/%d: %s", attempt + 1, attempts, code)
            time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...

        return response

# Wrap a plaintext channel with both interceptors. gRPC gives control to
# channel interceptors in the order listed, so the logging interceptor
# wraps the retry interceptor.
channel = grpc.intercept_channel(
    grpc.insecure_channel("localhost:50051"), LoggingInterceptor(), RetryInterceptor()
)

Go 客户端拦截器

// loggingInterceptor is a unary client interceptor. It logs each outgoing
// call's method, invokes it, and logs the elapsed time, together with the
// error when the call failed. The invoker's error is returned unchanged.
func loggingInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    log.Printf("Request: %s", method)

    t0 := time.Now()
    callErr := invoker(ctx, method, req, reply, cc, opts...)
    took := time.Since(t0)

    if callErr != nil {
        log.Printf("Error: %s (%v) - %v", method, took, callErr)
        return callErr
    }

    log.Printf("Response: %s (%v)", method, took)
    return nil
}

// retryInterceptor returns a unary client interceptor that makes up to
// maxRetries attempts in total (at least one, even when maxRetries < 1).
// It backs off exponentially (1s, 2s, 4s, ...) between attempts and does
// not sleep after the final attempt. If ctx is cancelled or its deadline
// passes while waiting, it stops and returns the last RPC error.
//
// NOTE(review): every error is retried. Consider retrying only transient
// codes (e.g. status.Code(err) == codes.Unavailable), since errors such as
// InvalidArgument or Unauthenticated fail identically on each attempt.
func retryInterceptor(maxRetries int) grpc.UnaryClientInterceptor {
    attempts := maxRetries
    if attempts < 1 {
        // Otherwise maxRetries <= 0 would skip the RPC entirely and report
        // success with a nil error.
        attempts = 1
    }

    return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        var lastErr error

        for attempt := 0; attempt < attempts; attempt++ {
            err := invoker(ctx, method, req, reply, cc, opts...)
            if err == nil {
                return nil
            }
            lastErr = err

            if attempt == attempts-1 {
                break // no point waiting after the final attempt
            }
            log.Printf("Retry %d/%d: %v", attempt+1, attempts, err)

            // Exponential backoff that still honours cancellation.
            backoff := time.NewTimer(time.Duration(1<<uint(attempt)) * time.Second)
            select {
            case <-ctx.Done():
                backoff.Stop()
                return lastErr
            case <-backoff.C:
            }
        }

        return lastErr
    }
}

// Dial the server with the client interceptors chained by grpc-go's built-in
// WithChainUnaryInterceptor. The first one listed is the outermost, so the
// logged duration covers every retry attempt.
// NOTE: grpc-go v1.63+ deprecates grpc.Dial in favour of grpc.NewClient.
conn, err := grpc.Dial(
    "localhost:50051",
    grpc.WithTransportCredentials(insecure.NewCredentials()),
    grpc.WithChainUnaryInterceptor(
        loggingInterceptor,
        retryInterceptor(3),
    ),
)
if err != nil {
    log.Fatalf("failed to dial: %v", err)
}

中间件库

go-grpc-middleware

import (
    "github.com/grpc-ecosystem/go-grpc-middleware"
    "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
    "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
    "github.com/grpc-ecosystem/go-grpc-middleware/ratelimit"
)

// 创建服务器
server := grpc.NewServer(
    grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
        grpc_recovery.UnaryServerInterceptor(),  // 恢复 panic
        grpc_zap.UnaryServerInterceptor(logger),  // 日志
        grpc_ratelimit.UnaryServerInterceptor(limiter),  // 限流
    )),
)

grpc-interceptor

from grpc_interceptor import ServerInterceptor

class LoggingInterceptor(ServerInterceptor):
    """Log each RPC handled by the server, using the grpc-interceptor package.

    ``intercept`` receives the next handler as ``method``; calling it runs the
    RPC. For server-streaming methods ``method`` may return an iterator, in
    which case errors raised while it is consumed are not caught here —
    verify against your servicers.
    """

    def intercept(self, method, request, context, method_name):
        logging.info("Request: %s", method_name)

        # Keep only the servicer call inside try, so a failure in the success
        # log below is never reported as an RPC error.
        try:
            response = method(request, context)
        except Exception as e:
            # logging.exception also records the traceback.
            logging.exception("Error: %s - %s", method_name, e)
            raise

        logging.info("Response: %s", method_name)
        return response

小结

拦截器与中间件要点:

  • 服务端拦截器:日志、认证、限流
  • 客户端拦截器:日志、重试
  • 中间件库:go-grpc-middleware(Go)、grpc-interceptor(Python)
  • 链式拦截器:组合多个拦截器

下一章我们将学习生产实践。