找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

2703

积分

1

好友

371

主题
发表于 前天 09:34 | 查看: 9| 回复: 0

深度学习的广阔领域中,分类问题是最常见且基础的任务之一。无论是识别图片中的物体,还是判断一段文本的情感,分类算法都扮演着核心角色。作为入门深度学习的必经之路,softmax回归提供了一个强大而直观的多分类解决方案。本文将带你从数学原理出发,结合经典的MNIST手写数字识别任务,使用Python的PyTorch框架,一步步实现并理解softmax回归。

什么是softmax回归?

基本概念

Softmax回归,也称为多类别逻辑回归(Multinomial Logistic Regression),可以看作是线性回归在多分类问题上的自然扩展。它的核心目标是为一个输入样本预测其属于多个离散类别中每一个的概率。

数学原理

假设一个模型(例如神经网络的最后一层)为输入样本计算出了 C 个类别的原始得分(logits),记为 z₁, z₂, ..., z_C。Softmax函数将这些原始得分转换为一个概率分布。对于第 i 个类别,其概率 p_i 的计算公式如下:

p_i = exp(z_i) / Σ_{j=1}^{C} exp(z_j)

该函数的特性保证了转换后的输出满足概率分布的所有条件:每个类别的概率值在 01 之间,并且所有类别的概率之和为 1。这使得模型的输出具有直接的可解释性。

特点与优势

  1. 输出可解释性强:直接输出每个类别的概率,便于理解模型对预测结果的置信度。
  2. 梯度计算友好:当与交叉熵损失函数(Cross-Entropy Loss)结合使用时,其梯度形式简单,有利于高效的反向传播。
  3. 泛化性好:作为基础的分类器,适用于广泛的多分类场景。
  4. 计算效率高:相比更复杂的非线性分类模型,其计算开销相对较小。

softmax回归在深度学习中的作用

作为分类问题的基础

在复杂的深度神经网络中,softmax层通常被放置在网络的最后一层。它的职责是将前面所有层学习到的高维特征表示,映射为最终的分类概率分布,为模型赋予解决多分类问题的能力。

与损失函数的配合

Softmax回归通常与交叉熵损失函数配合使用,构成一个经典的分类模块。交叉熵损失衡量了模型预测的概率分布与真实的“one-hot”标签分布之间的差异,为模型参数的优化提供了明确的方向。

应用场景

  • 图像分类:如MNIST手写数字识别、CIFAR-10物体识别等。
  • 自然语言处理:文本分类、情感分析、命名实体识别等。
  • 推荐系统:用户兴趣分类、物品类别预测等。
  • 医疗诊断:疾病分类、症状分析等。

实战案例:MNIST手写数字识别

让我们通过一个经典的实战项目——MNIST手写数字识别,来完整地走一遍softmax回归的建模流程。

数据集介绍

MNIST数据集是机器学习领域的“Hello World”,它包含了大量0到9的手写数字图片:

  • 训练集:60,000张28×28像素的灰度图像。
  • 测试集:10,000张28×28像素的灰度图像。
  • 每张图片都对应一个0-9的数字标签。

模型构建

我们将使用PyTorch框架来实现整个softmax回归模型。

4.2.1. 数据准备

首先,我们需要导入必要的库并准备数据。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子保证可重复性
torch.manual_seed(42)

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化数据
])

# 加载MNIST数据集
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

4.2.2. 定义模型

接下来,我们定义SoftmaxRegression模型类。

class SoftmaxRegression(nn.Module):
    def __init__(self):
        super(SoftmaxRegression, self).__init__()
        # 将28x28的图像展平为784维向量
        self.flatten = nn.Flatten()
        # 全连接层,输入784,输出10(10个数字类别)
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        x = self.flatten(x)
        # 注意:这里没有显式调用softmax。
        # 因为nn.CrossEntropyLoss内部已经结合了log_softmax,所以模型直接输出logits即可。
        return self.linear(x)

# 初始化模型
model = SoftmaxRegression()
print(model)

4.2.3. 定义损失函数和优化器

我们使用交叉熵损失和随机梯度下降(SGD)优化器。

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

4.2.4. 训练模型

定义训练函数,进行多轮训练。

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # 前向传播
        output = model(data)
        loss = criterion(output, target)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

4.2.5. 测试模型

定义测试函数,评估模型在测试集上的性能。

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # 将一批次的损失相加
            pred = output.argmax(dim=1, keepdim=True)  # 获取概率最大的类别
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')

    return accuracy

# 训练和测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

test_accuracies = []
epochs = 10

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    accuracy = test(model, device, test_loader, criterion)
    test_accuracies.append(accuracy)

4.2.6. 可视化结果

最后,我们可视化训练过程中的准确率变化和一些预测样本。

plt.figure(figsize=(12, 4))

# 绘制测试准确率变化
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs+1), test_accuracies)
plt.title('Test Accuracy Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.grid(True)

# 绘制一些测试样本及其预测结果
plt.subplot(1, 2, 2)
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
example_data = example_data.to(device)

with torch.no_grad():
    output = model(example_data)
    pred = output.argmax(dim=1, keepdim=True)

# 显示一些测试图像及其预测
for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.imshow(example_data[i].cpu().squeeze(), cmap='gray', interpolation='none')
    plt.title(f"Pred: {pred[i].item()}, True: {example_targets[i].item()}")
    plt.xticks([])
    plt.yticks([])

plt.tight_layout()
plt.savefig('mnist_softmax_results.png')
plt.show()

结果展示

运行上述代码后,我们将得到模型性能指标和可视化的预测结果。

MNIST手写数字识别softmax回归模型预测结果示例

从运行结果中,我们可以观察到以下几点:

  1. 训练过程:模型在10个训练周期内,损失函数值稳步下降,表明模型正在有效学习。
  2. 测试准确率:最终模型在测试集上的准确率大约能达到92%,这对于一个简单的线性模型来说是非常不错的结果。
  3. 预测示例:右侧的可视化图展示了模型对部分测试样本的预测情况,可以看到绝大多数数字都被正确识别了。

结果分析

这个实战案例清晰地展示了softmax回归的有效性:

  1. 即便是一个简单的线性模型(单层全连接),softmax回归也能在MNIST这类相对规整的数据集上取得很高的分类准确率。
  2. 它验证了数据预处理(如图像标准化)对于稳定和加速模型训练的重要性。
  3. 整个流程(数据加载、模型定义、训练循环、评估测试)构成了一个完整的机器学习项目模板,是深入学习更复杂模型的基础。

通过这个从理论到实践的完整过程,相信你已经对softmax回归有了扎实的理解。它不仅是深度学习的入门基石,其背后将模型输出转换为概率分布的思想,也贯穿于许多更先进的模型之中。希望这篇教程能帮助你在技术社区(如云栈社区)的探索之路上更进一步。




上一篇:Oracle 2026年Q1补丁发布解读:19c、21c、26ai等版本更新详情
下一篇:Griptape滑板流媒体平台发布:滑板视频专享、动作识别与收益分配解析
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-1-24 00:28 , Processed in 0.338155 second(s), 40 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表