本篇内容旨在帮助刚接触深度学习的读者,快速上手 PyTorch 的核心数据结构——张量(Tensor),并掌握如何高效选择与管理计算设备。你将了解从环境搭建、基础操作到性能优化的关键知识点,并配有可直接运行的代码示例。
🛠️ 环境准备
首先,我们通过以下命令创建一个虚拟环境并安装 PyTorch:
python -m venv .venv && source .venv/bin/activate
pip install torch torchvision
🧠 张量基础与设备选择
张量是 PyTorch 中的核心数据结构,它是一个多维数组,与 NumPy 的 ndarray 类似,但可以驻留在 GPU 等设备上以加速计算。正确地管理设备是高效利用硬件资源的第一步。
import torch
# 自动选择设备:优先 MPS (Apple 芯片) 或 CUDA (NVIDIA),否则回退到 CPU
device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
# 在选定设备上创建张量
x = torch.randn(3, 4, device=device)
y = torch.randn(4, 5, device=device)
# 执行矩阵乘法(@ 运算符)
z = x @ y
print("设备:", device)
print("x形状/类型:", x.shape, x.dtype)
print("z形状/类型:", z.shape, z.dtype)
核心要点:
- 设备自动选择:代码优先尝试使用 MPS(Apple 芯片)或 CUDA(NVIDIA GPU)加速,否则回退到 CPU。
- 常用操作:矩阵乘(
@)、逐元素加减乘除、以及维度/形状变换(view/reshape/permute)是张量最基础的操作。
📚 基础知识:线性代数与张量维度
维度、形状与步幅
- 向量与矩阵:向量是 1D 张量(形状如
[n]),矩阵是 2D 张量(形状如 [m, n])。更高维张量可理解为多维数组,例如图像数据常用 NCHW(批数、通道、高、宽)格式。
- 形状:决定了数据在各维度上的长度,是张量运算的首要约束。例如矩阵乘法要求前者的列数等于后者的行数。
- 步幅:描述在内存中访问相邻元素所需的跨度。调用
contiguous() 可以确保张量在内存中连续存储,这对某些需要连续内存的内核操作是必要的。
- 数据类型:常用的有
torch.float32(训练默认精度)、torch.float16/bfloat16(混合精度训练以节省显存)、torch.int64(常用于索引)。选择合适的 dtype 对计算精度和性能有直接影响。
线性运算规则速览
| 运算 |
PyTorch 表达 |
形状规则 |
| 转置 |
A.T 或 A.transpose(0, 1) |
[m, n] 变为 [n, m] |
| 点积 |
a @ b |
a.shape=[n]、b.shape=[n] |
| 矩阵乘 |
A @ B |
[m, k] @ [k, n] 得到 [m, n] |
理解本质:神经网络的线性层(如全连接层)本质上就是线性变换(矩阵乘法)与非线性激活函数的组合。理解并检查张量的 shape 是避免维度错误的关键。
🔬 实战进阶:基础数值操作与广播
import torch
device = torch.device("cpu")
a = torch.ones(2, 3, device=device) # [2, 3]
b = torch.randn(1, 3, device=device) # [1, 3]
c = a + b # 维度自动广播
print(c)
广播机制
- 机制:当两个张量的维度兼容时,PyTorch 会自动将维度较小的张量在长度为 1 的维度上“虚拟”复制,以匹配较大张量的形状,从而简化运算。
- 规则概要:从最后一个维度开始向前比较;如果两个维度相等,或其中一个为 1,则可以广播;否则不兼容。
- 对齐技巧:使用
unsqueeze 方法可以在指定位置添加一个大小为 1 的维度,以便于手动对齐。
import torch
x = torch.randn(4, 3) # [4, 3]
w = torch.randn(3) # [3]
y = x + w # w 隐式广播为 [1, 3],然后为 [4, 3]
# 显式对齐,意图更清晰
u = w.unsqueeze(0) # [1, 3]
y2 = x + u # 同样成立
实用建议:在高维或复杂运算中,主动使用 unsqueeze 和 expand 明确广播意图,可以降低隐式错误的风险,并使代码更易读。
🌐 设备管理与内存拷贝
有效的设备管理是 云原生 与高性能计算场景下的重要技能。
- 设备可用性检测:使用
torch.cuda.is_available()、torch.backends.mps.is_available()。
- 张量迁移:使用
x.to(device) 迁移张量,或在创建时直接指定 device 参数。注意:不同设备上的张量不能直接进行运算。
- 性能注意:频繁的设备间数据传输会带来显著开销,应尽量批量传输数据,并避免在循环中进行小张量的
to(device) 操作。
布局与内存:NCHW vs. NHWC
- 图像张量常用
NCHW(批数、通道、高、宽)格式存储。在某些后端(如 cuDNN)上,使用 NHWC(又称 channels_last)内存格式可以提升卷积等操作的效率。
- 连续性:部分操作(如某些视图操作或底层内核调用)要求张量在内存中是连续的。若不确定,可使用
x = x.contiguous() 创建一个连续的副本(注意这是拷贝操作)。
import torch
img = torch.randn(8, 3, 224, 224) # NCHW 格式
img_cl = img.to(memory_format=torch.channels_last) # 转换为 channels_last 格式
assert img_cl.is_contiguous(memory_format=torch.channels_last)
🔢 数值稳定性与精度
- 随机性与可复现:使用
torch.manual_seed(42) 设定全局随机种子,并在数据加载、数据增强等环节固定随机源,以确保实验可复现。
- 溢出与下溢:在使用
float16 等低精度格式时更易发生。可采用 bfloat16 或配合 GradScaler 的混合精度训练来缓解。
- NaN 与 Inf:常出现在不恰当的除法、对数或指数运算中。训练过程中应监控损失和梯度值,必要时进行梯度裁剪或数值规范化。
视图与拷贝:view/clone 的区别
- 视图:
view、reshape、flatten 等方法在可能时返回一个与原数据共享底层存储的新张量视图。对视图的修改会影响原始张量。
- 拷贝:
clone() 方法会显式创建数据的完整副本,拥有独立的存储空间。在需要避免原地修改影响源数据时使用。
import torch
x = torch.arange(12).reshape(3, 4)
v = x.view(-1) # 视图,与 x 共享存储
y = x.clone() # 拷贝,拥有独立存储
原地操作(In-place)与风险
- 原地操作通常以方法名后加下划线
_ 表示(如 relu_),它能减少内存占用,但可能会破坏计算图、影响梯度计算或导致后续视图不一致。
- 建议:在模型训练阶段谨慎使用原地操作。若出现梯度异常或形状问题,优先考虑将其改为非原地版本进行排查。
⚠️ 常见问题与解决方案
| 问题 |
描述 |
解决方案 |
| 设备不一致 |
在不同设备上的张量参与运算会报错。 |
确保所有参与运算的张量都通过 .to(device) 转移到同一设备。 |
| 形状不匹配 |
矩阵乘、张量拼接等操作有严格的维度规则。 |
仔细检查各张量的 shape,使用 view/reshape 或 unsqueeze 进行调整。 |
| GPU 计时不准 |
GPU/MPS 操作默认异步执行,直接计时会不准确。 |
在计时代码前后调用 torch.cuda.synchronize(),或使用更专业的 torch.profiler。 |
⚙️ 张量创建与常用 API 速览
| 类别 |
常用 API |
| 创建 |
torch.zeros / ones / full / rand / randn / eye / arange / linspace |
| 属性 |
x.shape / x.dtype / x.device / x.is_contiguous() |
| 变换 |
permute / transpose / view / reshape / flatten / squeeze / unsqueeze |
| 拼接 |
cat / stack (cat 在现有维度拼接,stack 会创建新维度) |
import torch
x = torch.linspace(0, 1, steps=5) # 创建包含5个元素的等差数列张量 [5]
mat = torch.eye(3) # 创建 3x3 单位矩阵
z = torch.stack([x, x], dim=0) # 沿新维度(dim=0)堆叠,得到 [2, 5]
📈 进阶:索引、切片与掩码
- 基础索引:如
x[i]、x[:, k] 等,通常返回原张量的一个视图。
- 高级索引:使用张量或列表进行索引(如
x[[0, 2, 4]]),这会返回数据的一个拷贝。
- 布尔掩码:用于根据条件过滤或替换元素。掩码的形状需要与目标张量兼容。
import torch
x = torch.randn(5)
mask = x > 0
x[mask] = 0.0 # 使用布尔掩码将大于0的元素置零
进阶:高阶张量算子
- 爱因斯坦求和(einsum):提供了一种紧凑而强大的方式来表达复杂的张量乘法与归约操作。
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B) # 等价于矩阵乘法 A @ B
- 批矩阵乘法(bmm):专门用于批量处理矩阵乘法,形状规则为
[b, m, k] @ [b, k, n] -> [b, m, n]。
进阶:索引高级技巧(聚合与散射)
这些操作在实现特定算法(如词袋模型、图神经网络)时非常高效。
- 聚合(indexadd):按给定的索引将源张量的值累加到目标张量的指定位置上。
idx = torch.tensor([0, 2, 1, 2]) # 索引
src = torch.tensor([1., 2., 3., 4.]) # 源数据
out = torch.zeros(3) # 目标张量
out.index_add_(0, idx, src) # out 变为 [1., 3., 6.]
- 散射(scatter)与采样(gather):分别用于按索引写入/累加和按索引提取,常见于注意力机制等需要对齐操作的场景。
参考资料与下一步