找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

1561

积分

0

好友

231

主题
发表于 4 小时前 | 查看: 1| 回复: 0

在上一篇文章中,我们探讨了智能水果识别系统的数据模块。本文将聚焦于模型模块,这是深度学习项目中将数据转化为识别能力的核心引擎。通过模型工厂类与加载器的设计,我们可以构建清晰、可维护的系统架构。

一、模型模块

模型工厂类

在智能水果识别系统中,我们创建了 models/model_factory.py 模块作为模型工厂类,用于统一创建不同类型的预训练深度学习模型。这种设计模式在Python项目中能有效提升代码的模块化与可维护性。

以下是 models/model_factory.py 的核心代码:

import torch.nn as nn
from torchvision import models
from config.settings import Config
from utils.logger import logger

class ModelFactory:
    """模型工厂类,用于创建不同类型的深度学习模型"""

    @staticmethod
    def create_model(model_name=None, num_classes=None, pretrained=True, freeze_backbone=False):
        """
        创建模型

        Args:
            model_name: 模型名称
            num_classes: 类别数量
            pretrained: 是否使用预训练权重
            freeze_backbone: 是否冻结主干网络

        Returns:
            创建的模型
        """
        # 使用配置中的默认值或传入的参数
        model_name = model_name or Config.MODEL_NAME
        num_classes = num_classes or Config.NUM_CLASSES

        logger.info(f"创建模型: {model_name}, 类别数: {num_classes}, 预训练: {pretrained}")

        # 根据模型名称创建不同的模型
        if model_name == 'efficientnet_b0':
            # 创建EfficientNet-B0模型
            model = models.efficientnet_b0(weights='DEFAULT' if pretrained else None)
            # 如果需要冻结主干网络,则冻结特征提取部分的参数
            if freeze_backbone:
                for param in model.features.parameters():
                    param.requires_grad = False

            # 替换分类器以适应指定的类别数
            num_features = model.classifier[1].in_features
            model.classifier = nn.Sequential(
                nn.Dropout(p=Config.DROPOUT_RATE, inplace=True),
                nn.Linear(num_features, num_classes)
            )

        elif model_name == 'resnet50':
            # 创建ResNet-50模型
            model = models.resnet50(weights='DEFAULT' if pretrained else None)
            # 如果需要冻结主干网络,则冻结所有参数
            if freeze_backbone:
                for param in model.parameters():
                    param.requires_grad = False

            # 替换最后的全连接层以适应指定的类别数
            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, num_classes)

        elif model_name == 'mobilenet_v2':
            # 创建MobileNet-V2模型
            model = models.mobilenet_v2(weights='DEFAULT' if pretrained else None)
            # 如果需要冻结主干网络,则冻结特征提取部分的参数
            if freeze_backbone:
                for param in model.features.parameters():
                    param.requires_grad = False

            # 替换分类器以适应指定的类别数
            num_features = model.classifier[1].in_features
            model.classifier = nn.Sequential(
                nn.Dropout(p=0.2),
                nn.Linear(num_features, num_classes)
            )

        elif model_name == 'densenet121':
            # 创建DenseNet-121模型
            model = models.densenet121(weights='DEFAULT' if pretrained else None)
            # 如果需要冻结主干网络,则冻结特征提取部分的参数
            if freeze_backbone:
                for param in model.features.parameters():
                    param.requires_grad = False

            # 替换分类器以适应指定的类别数
            num_features = model.classifier.in_features
            model.classifier = nn.Linear(num_features, num_classes)

        else:
            raise ValueError(f"不支持的模型名称: {model_name}")

        # 将模型移到指定设备(CPU或GPU)
        model = model.to(Config.DEVICE)

        logger.info(f"模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}")

        return model

    @staticmethod
    def count_parameters(model):
        """计算模型参数数量

        Args:
            model: PyTorch模型对象

        Returns:
            tuple: (总参数量, 可训练参数量)
        """
        # 计算模型的总参数量
        total_params = sum(p.numel() for p in model.parameters())
        # 计算可训练的参数量(未冻结的参数)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        logger.info(f"总参数量: {total_params:,}")
        logger.info(f"可训练参数量: {trainable_params:,}")
        logger.info(f"冻结参数量: {total_params - trainable_params:,}")

        return total_params, trainable_params

该模块定义了一个 ModelFactory 类,支持创建多种预训练模型(如 EfficientNet-B0、ResNet-50 等),允许自定义类别数量、预训练权重使用以及主干网络冻结,并提供了模型参数统计功能,实现了模型创建的标准化与中心化管理。

模型加载器类

为了处理模型的保存与加载,我们设计了 models/model_loader.py 模块作为模型加载器。这确保了训练与推理阶段模型权重的无缝管理。

以下是 models/model_loader.py 的代码实现:

import os
import torch
from config.settings import Config
from config.paths import PathManager
from utils.logger import logger
from .model_factory import ModelFactory

class ModelLoader:
    """模型加载器,负责模型的保存和加载操作"""

    @staticmethod
    def load_model(model_path=None, model_name=None, num_classes=None):
        """
        加载训练好的模型

        Args:
            model_path: 模型文件路径,如果为None则使用默认路径
            model_name: 模型名称,用于创建对应架构的模型
            num_classes: 类别数量,如果为None则自动从数据集中获取

        Returns:
            tuple: (加载的模型, 类别名称列表)
        """
        # 如果没有指定模型路径,则使用配置中的默认路径
        if model_path is None:
            model_path = PathManager.get_model_path('best')

        # 检查模型文件是否存在
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        logger.info(f"加载模型: {model_path}")

        # 获取类别信息,用于构建模型输出层
        from data.dataset import FruitDataset
        dataset = FruitDataset()
        class_names = dataset.get_class_names()
        num_classes = num_classes or len(class_names)

        # 创建模型架构(不使用预训练权重,因为我们即将加载自己的权重)
        model = ModelFactory.create_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=False  # 加载自己的权重,不需要预训练
        )

        # 加载模型权重到模型中,并将其移动到指定设备
        model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
        model.eval()  # 设置为评估模式

        logger.info(f"模型加载完成,设备: {Config.DEVICE}")

        return model, class_names

    @staticmethod
    def save_model(model, model_name='best', metadata=None):
        """
        保存训练好的模型权重

        Args:
            model: 要保存的模型
            model_name: 模型名称 ('best', 'final', 或自定义名称)
            metadata: 要保存的元数据(如训练配置、准确率等)
        """
        # 获取模型保存路径
        model_path = PathManager.get_model_path(model_name)

        # 保存模型权重(仅保存state_dict,减小文件大小)
        torch.save(model.state_dict(), model_path)

        # 如果提供了元数据,则单独保存元数据文件
        if metadata is not None:
            metadata_path = model_path.replace('.pth', '_metadata.pth')
            torch.save(metadata, metadata_path)

        logger.info(f"模型已保存: {model_path}")
        return model_path

    @staticmethod
    def save_checkpoint(epoch, model, optimizer, scheduler, best_acc, filename='checkpoint.pth'):
        """
        保存训练检查点,用于恢复训练

        Args:
            epoch: 当前训练轮次
            model: 模型对象
            optimizer: 优化器对象
            scheduler: 学习率调度器对象
            best_acc: 当前最佳准确率
            filename: 检查点文件名
        """
        # 构建检查点字典,包含恢复训练所需的所有信息
        checkpoint = {
            'epoch': epoch,                      # 当前训练轮次
            'model_state_dict': model.state_dict(),     # 模型权重
            'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,  # 调度器状态
            'best_acc': best_acc                 # 最佳准确率
        }

        # 保存检查点到指定路径
        checkpoint_path = os.path.join(Config.SAVE_DIR, filename)
        torch.save(checkpoint, checkpoint_path)

        logger.info(f"检查点已保存: {checkpoint_path}")
        return checkpoint_path

    @staticmethod
    def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
        """
        加载训练检查点,用于恢复训练

        Args:
            checkpoint_path: 检查点文件路径
            model: 模型对象
            optimizer: 优化器对象(可选)
            scheduler: 学习率调度器对象(可选)

        Returns:
            tuple: (加载的epoch, 最佳准确率)
        """
        # 检查检查点文件是否存在
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"检查点文件不存在: {checkpoint_path}")

        # 加载检查点文件
        checkpoint = torch.load(checkpoint_path, map_location=Config.DEVICE)

        # 加载模型权重
        model.load_state_dict(checkpoint['model_state_dict'])

        # 如果提供了优化器且检查点中包含优化器状态,则加载优化器状态
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # 如果提供了调度器且检查点中包含调度器状态,则加载调度器状态
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # 获取检查点中的训练信息
        epoch = checkpoint.get('epoch', 0)
        best_acc = checkpoint.get('best_acc', 0.0)

        logger.info(f"检查点已加载: {checkpoint_path}, epoch: {epoch}, best_acc: {best_acc:.2f}%")

        return epoch, best_acc

该模块定义了 ModelLoader 类,负责模型的加载、保存和检查点管理。它支持加载预训练模型、保存模型权重及元数据,并能处理训练中断后的恢复,确保了模型生命周期的高效管理。

二、总结

在智能水果识别系统中,模型模块通过工厂模式与加载器的分层设计,实现了模型创建、加载与应用的解耦,显著提升了系统的灵活性与可维护性:

  1. 标准化接口:统一的模型创建与加载接口屏蔽了底层框架差异,为训练和推理提供了稳定一致的调用方式。
  2. 中心化管理:模型工厂集中处理所有模型的构建逻辑,通过配置即可轻松切换不同网络架构,简化了实验流程。
  3. 高效运作:模型加载器专门负责权重载入、设备部署与优化,确保了推理效率并充分利用硬件资源。
  4. 易于扩展:模块化设计使得新增模型或替换骨干网络变得简单,系统核心代码无需改动即可集成新的研究成果。

模型是智能系统的核心引擎。一个优秀的模型模块不仅加速了研究迭代,也使从实验到部署的路径更加平滑。本文的设计确保了模型管理的简洁与强大,让开发者能更专注于算法创新与性能优化。




上一篇:全面解析Unity AnimatorController状态机设计与高级动画应用
下一篇:大语言模型评测全解析:从困惑度到GPT-4评估方法详解
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2025-12-24 18:57 , Processed in 0.289984 second(s), 38 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2025 云栈社区.

快速回复 返回顶部 返回列表