在上一篇文章中,我们探讨了智能水果识别系统的数据模块。本文将聚焦于模型模块,这是深度学习项目中将数据转化为识别能力的核心引擎。通过模型工厂类与加载器的设计,我们可以构建清晰、可维护的系统架构。
一、模型模块
模型工厂类
在智能水果识别系统中,我们创建了 models/model_factory.py 模块作为模型工厂类,用于统一创建不同类型的预训练深度学习模型。这种设计模式在Python项目中能有效提升代码的模块化与可维护性。
以下是 models/model_factory.py 的核心代码:
import torch.nn as nn
from torchvision import models
from config.settings import Config
from utils.logger import logger
class ModelFactory:
"""模型工厂类,用于创建不同类型的深度学习模型"""
@staticmethod
def create_model(model_name=None, num_classes=None, pretrained=True, freeze_backbone=False):
"""
创建模型
Args:
model_name: 模型名称
num_classes: 类别数量
pretrained: 是否使用预训练权重
freeze_backbone: 是否冻结主干网络
Returns:
创建的模型
"""
# 使用配置中的默认值或传入的参数
model_name = model_name or Config.MODEL_NAME
num_classes = num_classes or Config.NUM_CLASSES
logger.info(f"创建模型: {model_name}, 类别数: {num_classes}, 预训练: {pretrained}")
# 根据模型名称创建不同的模型
if model_name == 'efficientnet_b0':
# 创建EfficientNet-B0模型
model = models.efficientnet_b0(weights='DEFAULT' if pretrained else None)
# 如果需要冻结主干网络,则冻结特征提取部分的参数
if freeze_backbone:
for param in model.features.parameters():
param.requires_grad = False
# 替换分类器以适应指定的类别数
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
nn.Dropout(p=Config.DROPOUT_RATE, inplace=True),
nn.Linear(num_features, num_classes)
)
elif model_name == 'resnet50':
# 创建ResNet-50模型
model = models.resnet50(weights='DEFAULT' if pretrained else None)
# 如果需要冻结主干网络,则冻结所有参数
if freeze_backbone:
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层以适应指定的类别数
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
elif model_name == 'mobilenet_v2':
# 创建MobileNet-V2模型
model = models.mobilenet_v2(weights='DEFAULT' if pretrained else None)
# 如果需要冻结主干网络,则冻结特征提取部分的参数
if freeze_backbone:
for param in model.features.parameters():
param.requires_grad = False
# 替换分类器以适应指定的类别数
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
nn.Dropout(p=0.2),
nn.Linear(num_features, num_classes)
)
elif model_name == 'densenet121':
# 创建DenseNet-121模型
model = models.densenet121(weights='DEFAULT' if pretrained else None)
# 如果需要冻结主干网络,则冻结特征提取部分的参数
if freeze_backbone:
for param in model.features.parameters():
param.requires_grad = False
# 替换分类器以适应指定的类别数
num_features = model.classifier.in_features
model.classifier = nn.Linear(num_features, num_classes)
else:
raise ValueError(f"不支持的模型名称: {model_name}")
# 将模型移到指定设备(CPU或GPU)
model = model.to(Config.DEVICE)
logger.info(f"模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}")
return model
@staticmethod
def count_parameters(model):
"""计算模型参数数量
Args:
model: PyTorch模型对象
Returns:
tuple: (总参数量, 可训练参数量)
"""
# 计算模型的总参数量
total_params = sum(p.numel() for p in model.parameters())
# 计算可训练的参数量(未冻结的参数)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"总参数量: {total_params:,}")
logger.info(f"可训练参数量: {trainable_params:,}")
logger.info(f"冻结参数量: {total_params - trainable_params:,}")
return total_params, trainable_params
该模块定义了一个 ModelFactory 类,支持创建多种预训练模型(如 EfficientNet-B0、ResNet-50 等),允许自定义类别数量、预训练权重使用以及主干网络冻结,并提供了模型参数统计功能,实现了模型创建的标准化与中心化管理。
模型加载器类
为了处理模型的保存与加载,我们设计了 models/model_loader.py 模块作为模型加载器。这确保了训练与推理阶段模型权重的无缝管理。
以下是 models/model_loader.py 的代码实现:
import os
import torch
from config.settings import Config
from config.paths import PathManager
from utils.logger import logger
from .model_factory import ModelFactory
class ModelLoader:
"""模型加载器,负责模型的保存和加载操作"""
@staticmethod
def load_model(model_path=None, model_name=None, num_classes=None):
"""
加载训练好的模型
Args:
model_path: 模型文件路径,如果为None则使用默认路径
model_name: 模型名称,用于创建对应架构的模型
num_classes: 类别数量,如果为None则自动从数据集中获取
Returns:
tuple: (加载的模型, 类别名称列表)
"""
# 如果没有指定模型路径,则使用配置中的默认路径
if model_path is None:
model_path = PathManager.get_model_path('best')
# 检查模型文件是否存在
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
logger.info(f"加载模型: {model_path}")
# 获取类别信息,用于构建模型输出层
from data.dataset import FruitDataset
dataset = FruitDataset()
class_names = dataset.get_class_names()
num_classes = num_classes or len(class_names)
# 创建模型架构(不使用预训练权重,因为我们即将加载自己的权重)
model = ModelFactory.create_model(
model_name=model_name,
num_classes=num_classes,
pretrained=False # 加载自己的权重,不需要预训练
)
# 加载模型权重到模型中,并将其移动到指定设备
model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
model.eval() # 设置为评估模式
logger.info(f"模型加载完成,设备: {Config.DEVICE}")
return model, class_names
@staticmethod
def save_model(model, model_name='best', metadata=None):
"""
保存训练好的模型权重
Args:
model: 要保存的模型
model_name: 模型名称 ('best', 'final', 或自定义名称)
metadata: 要保存的元数据(如训练配置、准确率等)
"""
# 获取模型保存路径
model_path = PathManager.get_model_path(model_name)
# 保存模型权重(仅保存state_dict,减小文件大小)
torch.save(model.state_dict(), model_path)
# 如果提供了元数据,则单独保存元数据文件
if metadata is not None:
metadata_path = model_path.replace('.pth', '_metadata.pth')
torch.save(metadata, metadata_path)
logger.info(f"模型已保存: {model_path}")
return model_path
@staticmethod
def save_checkpoint(epoch, model, optimizer, scheduler, best_acc, filename='checkpoint.pth'):
"""
保存训练检查点,用于恢复训练
Args:
epoch: 当前训练轮次
model: 模型对象
optimizer: 优化器对象
scheduler: 学习率调度器对象
best_acc: 当前最佳准确率
filename: 检查点文件名
"""
# 构建检查点字典,包含恢复训练所需的所有信息
checkpoint = {
'epoch': epoch, # 当前训练轮次
'model_state_dict': model.state_dict(), # 模型权重
'optimizer_state_dict': optimizer.state_dict(), # 优化器状态
'scheduler_state_dict': scheduler.state_dict() if scheduler else None, # 调度器状态
'best_acc': best_acc # 最佳准确率
}
# 保存检查点到指定路径
checkpoint_path = os.path.join(Config.SAVE_DIR, filename)
torch.save(checkpoint, checkpoint_path)
logger.info(f"检查点已保存: {checkpoint_path}")
return checkpoint_path
@staticmethod
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
"""
加载训练检查点,用于恢复训练
Args:
checkpoint_path: 检查点文件路径
model: 模型对象
optimizer: 优化器对象(可选)
scheduler: 学习率调度器对象(可选)
Returns:
tuple: (加载的epoch, 最佳准确率)
"""
# 检查检查点文件是否存在
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"检查点文件不存在: {checkpoint_path}")
# 加载检查点文件
checkpoint = torch.load(checkpoint_path, map_location=Config.DEVICE)
# 加载模型权重
model.load_state_dict(checkpoint['model_state_dict'])
# 如果提供了优化器且检查点中包含优化器状态,则加载优化器状态
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# 如果提供了调度器且检查点中包含调度器状态,则加载调度器状态
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# 获取检查点中的训练信息
epoch = checkpoint.get('epoch', 0)
best_acc = checkpoint.get('best_acc', 0.0)
logger.info(f"检查点已加载: {checkpoint_path}, epoch: {epoch}, best_acc: {best_acc:.2f}%")
return epoch, best_acc
该模块定义了 ModelLoader 类,负责模型的加载、保存和检查点管理。它支持加载预训练模型、保存模型权重及元数据,并能处理训练中断后的恢复,确保了模型生命周期的高效管理。
二、总结
在智能水果识别系统中,模型模块通过工厂模式与加载器的分层设计,实现了模型创建、加载与应用的解耦,显著提升了系统的灵活性与可维护性:
- 标准化接口:统一的模型创建与加载接口屏蔽了底层框架差异,为训练和推理提供了稳定一致的调用方式。
- 中心化管理:模型工厂集中处理所有模型的构建逻辑,通过配置即可轻松切换不同网络架构,简化了实验流程。
- 高效运作:模型加载器专门负责权重载入、设备部署与优化,确保了推理效率并充分利用硬件资源。
- 易于扩展:模块化设计使得新增模型或替换骨干网络变得简单,系统核心代码无需改动即可集成新的研究成果。
模型是智能系统的核心引擎。一个优秀的模型模块不仅加速了研究迭代,也使从实验到部署的路径更加平滑。本文的设计确保了模型管理的简洁与强大,让开发者能更专注于算法创新与性能优化。