跳转至

第六章:模型部署

部署概述

训练好的模型需要部署到生产环境才能实际使用。本章介绍几种常见的部署方式。

部署方式对比

方式 适用场景 优点 缺点
Python API 快速原型 简单直接 性能一般
ONNX 跨平台部署 兼容性好 需要转换
TorchScript 生产环境 性能好 调试困难
TensorRT 高性能推理 极快 仅限 NVIDIA GPU

方式一:Python API 部署

使用 Flask

# app.py
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io

app = Flask(__name__)

# 定义模型
class MNISTNet(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 scores.

    The input is flattened, passed through two ReLU-activated hidden layers
    (256 and 128 units), and projected to 10 raw scores. No softmax is
    applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys), so they must stay exactly as they are.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Load the model once at import time: pick CUDA when available, otherwise CPU,
# restore the saved weights onto that device, and switch to eval mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MNISTNet().to(device)
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
model.eval()

# Preprocessing: convert to grayscale, resize to 28x28, convert to a tensor,
# then normalize.
# NOTE(review): 0.1307 / 0.3081 are presumably the mean/std used when the
# model was trained -- confirm they match the training pipeline.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded digit image.

    Expects a multipart/form-data POST carrying the image in the ``image``
    field.

    Returns:
        A JSON response with ``prediction`` (class index), ``confidence``
        (softmax probability of that class) and ``probabilities`` (all
        class probabilities). Responds with HTTP 400 and an ``error`` field
        when no file was uploaded or the upload cannot be decoded as an
        image.
    """
    if 'image' not in request.files:
        return jsonify({'error': '没有上传图片'}), 400

    file = request.files['image']
    try:
        # PIL decodes lazily, so both opening and preprocessing can fail on a
        # bad upload. Its decode errors (including UnidentifiedImageError)
        # are OSError subclasses; report them as a client error, not a 500.
        with Image.open(io.BytesIO(file.read())) as image:
            img_tensor = transform(image).unsqueeze(0).to(device)
    except OSError:
        return jsonify({'error': '无法识别的图片文件'}), 400

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img_tensor)
        probability = torch.softmax(output, dim=1)[0]
        predicted = output.argmax(dim=1).item()

    return jsonify({
        'prediction': predicted,
        'confidence': probability[predicted].item(),
        'probabilities': probability.tolist()
    })

@app.route('/health', methods=['GET'])
def health():
    """Health check: always answers with a fixed healthy status payload."""
    status_payload = {'status': 'healthy'}
    return jsonify(status_payload)

if __name__ == '__main__':
    # Listen on all interfaces, port 5000 (the port the client example and
    # the Docker EXPOSE/-p mapping in this chapter use).
    # NOTE(review): app.run() starts Flask's built-in server; presumably fine
    # for this tutorial -- confirm a production WSGI server is used for real
    # deployments.
    app.run(host='0.0.0.0', port=5000)

测试 API

import requests

# Upload an image to the prediction endpoint.
with open('test_image.png', 'rb') as f:
    response = requests.post(
        'http://localhost:5000/predict',
        files={'image': f},
        # requests applies no timeout by default; without one this call can
        # block forever if the server is down or stalls.
        timeout=10,
    )

print(response.json())
# Example output: {'prediction': 7, 'confidence': 0.98, 'probabilities': [...]}

使用 Docker 部署

# Dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install dependencies.
# --index-url REPLACES the default package index instead of adding to it, so
# only torch/torchvision are installed from the PyTorch CPU wheel index.
# flask and pillow are installed in a second step from the default index,
# because the PyTorch index is not guaranteed to host them.
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu \
    && pip install --no-cache-dir flask pillow

# Copy the application and the trained weights
COPY app.py .
COPY mnist_model.pth .

# Port the Flask app listens on
EXPOSE 5000

# Start the service
CMD ["python", "app.py"]
# --- Shell commands: run these in a terminal, they are NOT part of the Dockerfile ---
# Build the image from the Dockerfile in the current directory
docker build -t mnist-api .

# Run the container in the background, publishing container port 5000 on host port 5000
docker run -d -p 5000:5000 mnist-api

方式二:ONNX 部署

ONNX(Open Neural Network Exchange)是一种开放的模型格式,支持跨框架部署。

导出 ONNX 模型

import torch
import torch.nn as nn

class MNISTNet(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 scores.

    The input is flattened, passed through two ReLU-activated hidden layers
    (256 and 128 units), and projected to 10 raw scores. No softmax is
    applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys), so they must stay exactly as they are.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on
# a machine without CUDA; the export below runs on CPU anyway.
model = MNISTNet()
model.load_state_dict(torch.load('mnist_model.pth', map_location='cpu'))
model.eval()

# Example input used to trace the graph during export
dummy_input = torch.randn(1, 1, 28, 28)

# Export to ONNX. dynamic_axes marks dimension 0 of both the input and the
# output as a named, variable-size batch dimension.
torch.onnx.export(
    model,
    dummy_input,
    "mnist_model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

print("ONNX 模型已导出")

使用 ONNX Runtime 推理

import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms

# Load the exported ONNX model into an inference session.
# NOTE(review): no execution providers are passed; depending on the installed
# onnxruntime build an explicit providers=[...] argument may be needed --
# confirm.
session = ort.InferenceSession("mnist_model.onnx")

# Look up the names of the graph's first input and first output; they are
# used as keys when calling session.run().
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Preprocessing: grayscale, resize to 28x28, convert to a tensor, normalize.
# NOTE(review): presumably this must match the preprocessing used at training
# time -- confirm.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 预测
def predict_onnx(image_path):
    """Classify the image at ``image_path`` with the ONNX Runtime session.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a path).

    Returns:
        The index of the highest-scoring class for the single input image.
    """
    # Open through a context manager so the underlying file is released
    # deterministically, even if preprocessing raises.
    with Image.open(image_path) as image:
        img_tensor = transform(image).unsqueeze(0).numpy()

    # Run ONNX Runtime inference, requesting only the named output.
    outputs = session.run([output_name], {input_name: img_tensor})

    # outputs[0] holds one row of scores per image; take the best class of
    # the first (only) row.
    predicted = np.argmax(outputs[0], axis=1)[0]
    return predicted

print(predict_onnx('test_image.png'))

方式三:TorchScript 部署

TorchScript 是 PyTorch 的原生序列化格式,支持 C++ 部署。

转换为 TorchScript

import torch
import torch.nn as nn

class MNISTNet(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 scores.

    The input is flattened, passed through two ReLU-activated hidden layers
    (256 and 128 units), and projected to 10 raw scores. No softmax is
    applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys), so they must stay exactly as they are.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on
# a machine without CUDA; the conversion below runs on CPU anyway.
model = MNISTNet()
model.load_state_dict(torch.load('mnist_model.pth', map_location='cpu'))
model.eval()

# Method 1: trace -- record the operations executed for an example input
dummy_input = torch.randn(1, 1, 28, 28)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("mnist_traced.pt")

# Method 2: script -- compile the module's Python source directly
scripted_model = torch.jit.script(model)
scripted_model.save("mnist_scripted.pt")

print("TorchScript 模型已保存")

加载 TorchScript 模型

import torch

# Load the serialized TorchScript module and put it in eval mode
model = torch.jit.load("mnist_traced.pt")
model.eval()

# Run one forward pass on a random input (a smoke test, not a real image)
# and take the highest-scoring class; gradients are disabled for inference.
with torch.no_grad():
    output = model(torch.randn(1, 1, 28, 28))
    predicted = output.argmax(dim=1).item()

print(f"预测结果: {predicted}")

C++ 部署

// main.cpp
#include <torch/script.h>
#include <iostream>

// Loads the traced TorchScript model, runs one forward pass on a random
// input, and prints the index of the highest-scoring class.
// Returns 0 on success, -1 if the model file cannot be loaded.
int main() {
    // Load the model
    torch::jit::script::Module model;
    try {
        model = torch::jit::load("mnist_traced.pt");
    } catch (const c10::Error& e) {
        // Include the underlying reason so a load failure can be diagnosed.
        std::cerr << "加载模型失败: " << e.what() << std::endl;
        return -1;
    }

    // Mirror the Python loader: eval mode and no gradient tracking, since
    // this program only performs inference.
    model.eval();
    torch::NoGradGuard no_grad;

    // Build the input: a single random tensor with the traced input shape
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({1, 1, 28, 28}));

    // Run inference and pick the class with the highest score
    at::Tensor output = model.forward(inputs).toTensor();
    int predicted = output.argmax(1).item<int>();

    std::cout << "预测结果: " << predicted << std::endl;

    return 0;
}

方式四:FastAPI + 异步部署

# main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io
import asyncio

app = FastAPI(title="MNIST 识别 API")

# Module-level handles shared by the request handlers. They start as None and
# are assigned by the startup hook in this file.
model = None
device = None

class MNISTNet(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 scores.

    The input is flattened, passed through two ReLU-activated hidden layers
    (256 and 128 units), and projected to 10 raw scores. No softmax is
    applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys), so they must stay exactly as they are.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

@app.on_event("startup")
async def startup():
    """Load the model once when the application starts.

    Assigns the module-level ``device`` and ``model``: chooses CUDA when it
    is available (CPU otherwise), restores the saved weights onto that
    device, and switches the model to eval mode.
    """
    global model, device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = MNISTNet().to(device)
    state = torch.load('mnist_model.pth', map_location=device)
    model.load_state_dict(state)
    model.eval()

# Built once at import time: the pipeline holds no per-request state, so
# rebuilding it on every call is wasted work.
_PREPROCESS = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


def _classify(contents):
    """Decode raw image bytes and run the model (blocking).

    Args:
        contents: The uploaded file's bytes.

    Returns:
        A ``(predicted, probabilities)`` pair: the class index with the
        highest score and the list of softmax probabilities for all classes.

    Raises:
        OSError: If PIL cannot decode ``contents`` as an image.
    """
    with Image.open(io.BytesIO(contents)) as image:
        img_tensor = _PREPROCESS(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
        predicted = output.argmax(dim=1).item()

    return predicted, probabilities.tolist()


# The docstring below is kept verbatim: FastAPI publishes route docstrings in
# the generated API documentation.
# Responds with prediction / confidence / probabilities, or HTTP 400 with an
# "error" field when the upload cannot be decoded as an image.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """预测接口"""
    contents = await file.read()

    try:
        # Decoding and the forward pass are blocking; run them in the default
        # executor so the event loop can keep serving other requests.
        loop = asyncio.get_running_loop()
        predicted, probabilities = await loop.run_in_executor(
            None, _classify, contents
        )
    except OSError:
        # PIL's decode errors (including UnidentifiedImageError) are OSError
        # subclasses; report them as a client error instead of a 500.
        return JSONResponse({"error": "无法识别的图片文件"}, status_code=400)

    return JSONResponse({
        "prediction": predicted,
        "confidence": round(probabilities[predicted], 4),
        "probabilities": [round(p, 4) for p in probabilities]
    })

# Health check: reports a fixed healthy status plus the device in use.
# The docstring is kept verbatim because FastAPI publishes route docstrings
# in the generated API documentation.
@app.get("/health")
async def health():
    """健康检查"""
    device_name = str(device)
    return {"status": "healthy", "device": device_name}

# 运行: uvicorn main:app --host 0.0.0.0 --port 8000

性能优化

1. 批处理推理

def batch_predict(model, images, batch_size=32):
    """Predict class indices for ``images`` in mini-batches.

    Args:
        model: Module that is switched to eval mode and called on each slice
            of ``images``; its output must have the class scores along
            dimension 1.
        images: Sliceable collection of samples indexed along its first
            dimension (for example a tensor), supporting ``len()``.
        batch_size: Number of samples per forward pass; must be positive.

    Returns:
        A flat list with one predicted class index per sample, in input
        order. Empty when ``images`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A negative step would make range() empty and silently return no
    # predictions at all; reject it explicitly (0 already raised ValueError).
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    model.eval()
    predictions = []

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            predictions.extend(model(batch).argmax(dim=1).tolist())

    return predictions

2. 半精度推理

# Use FP16 (half precision) to speed up inference.
# NOTE(review): presumably intended for a CUDA device -- half-precision
# support on CPU varies by PyTorch version; confirm before using it there.
model = model.half()  # convert the model to half precision
img_tensor = img_tensor.half()  # cast the input the same way as the model

# Inference only: no gradients are needed.
with torch.no_grad():
    output = model(img_tensor)

3. GPU 优化

# Use CUDA when it is available and fall back to CPU otherwise, matching the
# device selection used by the serving examples in this chapter. Hard-coding
# 'cuda' makes model.to(device) fail on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Use multiple GPUs: wrap the model when more than one is visible
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

监控和日志

import logging
from datetime import datetime

# Configure logging: INFO and above, timestamped, written to inference.log.
# The file handler is created explicitly so the encoding can be pinned to
# UTF-8: the messages logged below contain non-ASCII text, and the platform
# default encoding is not guaranteed to be able to represent it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('inference.log', encoding='utf-8')]
)

logger = logging.getLogger(__name__)

def predict_with_logging(image):
    """Run ``predict(image)`` and log its result and latency.

    Args:
        image: Passed through unchanged to ``predict``.

    Returns:
        Whatever ``predict(image)`` returns.

    Raises:
        Exception: Any exception raised by ``predict`` is logged (with its
            traceback) and then re-raised unchanged.
    """
    # Local import keeps this snippet self-contained. perf_counter() is a
    # monotonic clock, so the measured duration cannot be skewed (or go
    # negative) when the wall clock is adjusted, unlike datetime.now().
    import time

    start_time = time.perf_counter()

    try:
        result = predict(image)
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which a plain error message would lose.
        logger.exception(f"预测失败: {str(e)}")
        raise

    duration = time.perf_counter() - start_time
    logger.info(f"预测成功 - 结果: {result}, 耗时: {duration:.3f}s")
    return result

小结

本章学习了:

  • ✅ Flask/FastAPI 部署模型
  • ✅ ONNX 跨平台部署
  • ✅ TorchScript 生产部署
  • ✅ Docker 容器化部署
  • ✅ 性能优化技巧

部署方式选择

场景 推荐方式
快速原型 Flask/FastAPI
生产环境 TorchScript + FastAPI
跨平台 ONNX
高性能 TensorRT

PyTorch 教程总结

恭喜你完成了 PyTorch 教程!你学到了:

  1. 第一章:PyTorch 简介和安装
  2. 第二章:张量操作和常见问题
  3. 第三章:自动求导机制
  4. 第四章:神经网络基础
  5. 第五章:MNIST 实战项目
  6. 第六章:模型部署

继续学习:Transformer 架构