第六章:模型部署¶
部署概述¶
训练好的模型需要部署到生产环境才能实际使用。本章介绍几种常见的部署方式。
部署方式对比¶
| 方式 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|
| Python API | 快速原型 | 简单直接 | 性能一般 |
| ONNX | 跨平台部署 | 兼容性好 | 需要转换 |
| TorchScript | 生产环境 | 性能好 | 调试困难 |
| TensorRT | 高性能推理 | 极快 | 仅限 NVIDIA GPU |
方式一:Python API 部署¶
使用 Flask¶
# app.py
# Flask-based inference service for the MNIST classifier: exposes /predict and /health.
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io

app = Flask(__name__)
# 定义模型
class MNISTNet(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# Load the trained model onto GPU when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MNISTNet().to(device)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (newer PyTorch supports weights_only=True).
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
model.eval()  # inference mode: disables dropout/batchnorm training behavior

# Preprocessing pipeline; 0.1307/0.3081 are the standard MNIST mean/std
# (presumably matching how the checkpoint was trained — verify).
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
@app.route('/predict', methods=['POST'])
def predict():
    """预测接口: classify one uploaded image.

    Expects a multipart form field named 'image'. Returns JSON with the
    predicted digit, its softmax confidence, and the full probability vector.
    Responds 400 when the field is missing or the payload is not an image.
    """
    if 'image' not in request.files:
        return jsonify({'error': '没有上传图片'}), 400
    file = request.files['image']
    try:
        image = Image.open(io.BytesIO(file.read()))
    except Exception:
        # A corrupt or non-image upload would otherwise surface as a 500.
        return jsonify({'error': '无法解析图片'}), 400
    # Preprocess to a (1, 1, 28, 28) normalized tensor on the model's device.
    img_tensor = transform(image).unsqueeze(0).to(device)
    # Inference without building the autograd graph.
    with torch.no_grad():
        output = model(img_tensor)
        _, predicted = output.max(1)
        probability = torch.softmax(output, dim=1)
    return jsonify({
        'prediction': predicted.item(),
        'confidence': probability[0][predicted].item(),
        'probabilities': probability[0].tolist()
    })
@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint; always reports the service as healthy."""
    payload = {'status': 'healthy'}
    return jsonify(payload)
if __name__ == '__main__':
    # Development server only — front with gunicorn/uWSGI in production.
    app.run(host='0.0.0.0', port=5000)
测试 API¶
import requests

# Upload an image file and request a prediction from the local service.
with open('test_image.png', 'rb') as f:
    response = requests.post(
        'http://localhost:5000/predict',
        files={'image': f}
    )
print(response.json())
# {'prediction': 7, 'confidence': 0.98, 'probabilities': [...]}
使用 Docker 部署¶
# Dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install CPU-only PyTorch wheels from the dedicated index, then the web
# dependencies from PyPI. A single `pip install ... --index-url .../whl/cpu`
# would fail: that index only hosts torch-family packages, not flask/pillow.
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu \
    && pip install --no-cache-dir flask pillow

# Copy application code and model weights
COPY app.py .
COPY mnist_model.pth .

# Service port
EXPOSE 5000

# Start the API server
CMD ["python", "app.py"]
方式二:ONNX 部署¶
ONNX(Open Neural Network Exchange)是一种开放的模型格式,支持跨框架部署。
导出 ONNX 模型¶
import torch
import torch.nn as nn
class MNISTNet(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# Load the trained weights into a fresh model instance.
model = MNISTNet()
model.load_state_dict(torch.load('mnist_model.pth'))
model.eval()  # export in inference mode

# Example input fixing the (batch, channel, height, width) layout; the
# batch dimension is declared dynamic below so any batch size works.
dummy_input = torch.randn(1, 1, 28, 28)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "mnist_model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
print("ONNX 模型已导出")
使用 ONNX Runtime 推理¶
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms
# Load the exported ONNX model.
session = ort.InferenceSession("mnist_model.onnx")

# Tensor names assigned at export time ('input' / 'output').
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Preprocessing — same normalization used for the PyTorch model above.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Inference helper
def predict_onnx(image_path):
    """Classify the image at *image_path*; returns the predicted digit index."""
    img = Image.open(image_path)
    batch = transform(img).unsqueeze(0).numpy()
    # Single output head: logits with shape (1, 10).
    logits = session.run([output_name], {input_name: batch})[0]
    return logits.argmax(axis=1)[0]

print(predict_onnx('test_image.png'))
方式三:TorchScript 部署¶
TorchScript 是 PyTorch 的原生序列化格式,支持 C++ 部署。
转换为 TorchScript¶
import torch
import torch.nn as nn
class MNISTNet(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# Load trained weights
model = MNISTNet()
model.load_state_dict(torch.load('mnist_model.pth'))
model.eval()

# Option 1: trace — records the ops executed for one example input;
# input-dependent control flow is not captured.
dummy_input = torch.randn(1, 1, 28, 28)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("mnist_traced.pt")

# Option 2: script — compiles forward() from source, preserving
# data-dependent control flow.
scripted_model = torch.jit.script(model)
scripted_model.save("mnist_scripted.pt")
print("TorchScript 模型已保存")
加载 TorchScript 模型¶
import torch
# Load the serialized TorchScript module — no Python class definition needed.
model = torch.jit.load("mnist_traced.pt")
model.eval()

# Run a prediction on random input
with torch.no_grad():
    output = model(torch.randn(1, 1, 28, 28))
    predicted = output.argmax(dim=1).item()
print(f"预测结果: {predicted}")
C++ 部署¶
// main.cpp
// Minimal LibTorch example: load a traced model and run one inference.
#include <torch/script.h>
#include <iostream>

int main() {
    // Load the TorchScript model from disk.
    torch::jit::script::Module model;
    try {
        model = torch::jit::load("mnist_traced.pt");
    } catch (const c10::Error& e) {
        std::cerr << "加载模型失败" << std::endl;
        return -1;
    }

    // Build the input batch (random data, NCHW layout).
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({1, 1, 28, 28}));

    // Forward pass; the returned IValue holds the logits tensor.
    at::Tensor output = model.forward(inputs).toTensor();
    int predicted = output.argmax(1).item<int>();
    std::cout << "预测结果: " << predicted << std::endl;
    return 0;
}
方式四:FastAPI + 异步部署¶
# main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io
import asyncio
app = FastAPI(title="MNIST 识别 API")

# Globals populated by the startup hook below.
model = None
device = None
class MNISTNet(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# NOTE(review): @app.on_event is deprecated in newer FastAPI; the lifespan
# context-manager API is the recommended replacement.
@app.on_event("startup")
async def startup():
    """Load the model once at service startup instead of per request."""
    global model, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MNISTNet().to(device)
    model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
    model.eval()  # inference mode
# Built once at import time — the original rebuilt this pipeline per request.
_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """预测接口: classify one uploaded image.

    Returns JSON with the predicted digit, its rounded confidence, and the
    full probability vector. Responds 400 when the payload is not an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
    except Exception:
        # A corrupt or non-image upload would otherwise surface as a 500.
        return JSONResponse({"error": "无法解析图片"}, status_code=400)
    # Preprocess to a (1, 1, 28, 28) normalized tensor on the model's device.
    img_tensor = _transform(image).unsqueeze(0).to(device)
    # Inference without building the autograd graph.
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.softmax(output, dim=1)
    predicted = output.argmax(dim=1).item()
    confidence = probabilities[0][predicted].item()
    return JSONResponse({
        "prediction": predicted,
        "confidence": round(confidence, 4),
        "probabilities": [round(p, 4) for p in probabilities[0].tolist()]
    })
@app.get("/health")
async def health():
    """Report service liveness and the device inference runs on."""
    payload = {"status": "healthy", "device": str(device)}
    return payload
# 运行: uvicorn main:app --host 0.0.0.0 --port 8000
性能优化¶
1. 批处理推理¶
def batch_predict(model, images, batch_size=32):
    """Run inference over *images* in chunks of *batch_size*.

    Returns a flat list with the predicted class index for every input,
    in the original order.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            results += model(chunk).argmax(dim=1).tolist()
    return results
2. 半精度推理¶
# FP16 inference: halves parameter memory; speedups need FP16-capable GPUs.
# NOTE(review): half precision on CPU can be slow or unsupported for some ops.
model = model.half()  # convert parameters to half precision
img_tensor = img_tensor.half()  # inputs must match the parameter dtype
with torch.no_grad():
    output = model(img_tensor)
3. GPU 优化¶
# Move the model to CUDA
device = torch.device('cuda')
model = model.to(device)

# Multi-GPU data parallelism
# NOTE(review): nn.DataParallel is legacy; DistributedDataParallel is the
# recommended multi-GPU path in current PyTorch.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
监控和日志¶
import logging
from datetime import datetime
# Configure root logging: INFO level, timestamped lines, written to a file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='inference.log'
)
logger = logging.getLogger(__name__)  # module-scoped logger
def predict_with_logging(image):
    """Run predict() on *image*, logging the outcome and latency.

    Logs and re-raises any exception so callers still observe the failure.
    """
    start_time = datetime.now()
    try:
        result = predict(image)
        duration = (datetime.now() - start_time).total_seconds()
        # Lazy %-style args: the message is only formatted when the level is enabled.
        logger.info("预测成功 - 结果: %s, 耗时: %.3fs", result, duration)
        return result
    except Exception as e:
        logger.error("预测失败: %s", e)
        raise
小结¶
本章学习了:
- ✅ Flask/FastAPI 部署模型
- ✅ ONNX 跨平台部署
- ✅ TorchScript 生产部署
- ✅ Docker 容器化部署
- ✅ 性能优化技巧
部署方式选择:
| 场景 | 推荐方式 |
|---|---|
| 快速原型 | Flask/FastAPI |
| 生产环境 | TorchScript + FastAPI |
| 跨平台 | ONNX |
| 高性能 | TensorRT |
PyTorch 教程总结¶
恭喜你完成了 PyTorch 教程!你学到了:
- 第一章:PyTorch 简介和安装
- 第二章:张量操作和常见问题
- 第三章:自动求导机制
- 第四章:神经网络基础
- 第五章:MNIST 实战项目
- 第六章:模型部署
继续学习:Transformer 架构