带注释的PyTorch训练循环详解
原标题:The annotated PyTorch training loop
速览
本文提供了一份带有详细注释的PyTorch训练循环代码示例。该指南旨在帮助开发者理解模型训练的核心流程与关键步骤。对于初学者而言,这是掌握PyTorch框架基础的重要参考资料。
AI 深度解读
The annotated PyTorch training loop 深度解读
背景
构建 PyTorch 训练循环在表面上看似简单,但要确保所有组件处于正确的位置并以正确的顺序执行,却往往令人感到意外地脆弱。训练过程中涉及大量动态部分,在修复了最基础错误之后,其余的错误往往难以察觉。如果代码行顺序错乱,训练可能无法收敛、产生错误结果,或者消耗过多的内存。
尽管分布式训练、FSDP(Fully Sharded Data Parallel)和多 GPU 设置是高级话题,但掌握基础训练循环的每一个操作及其潜在陷阱是进行更复杂架构设计的前提。本文旨在通过逐行解析,揭示 PyTorch 训练循环中每个操作的具体作用,以及如果移动这些代码行会发生什么,帮助开发者避免那些不会抛出异常但会导致训练失败的隐蔽错误。
核心内容
1. 完整的训练循环结构
一个标准的 PyTorch 训练循环包含数据加载、模型定义、损失函数与优化器设置,以及核心的训练与验证步骤。
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# --- 数据部分 ---
dataset = TensorDataset(X_train, y_train)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
# --- 模型、损失函数、优化器 ---
model = MLP(in_features=2, hidden=128, out_features=3)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)
# --- 训练循环 ---
for epoch in range(100):
model.train()
for X_batch, y_batch in loader:
optimiser.zero_grad()
logits = model(X_batch)
loss = criterion(logits, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimiser.step()
scheduler.step()
model.eval()
with torch.no_grad():
val_logits = model(X_val)
val_loss = criterion(val_logits, y_val)
2. 数据管道(Data Pipeline)详解
PyTorch 的数据管道由 Dataset 和 DataLoader 两部分组成。
- Dataset:是一个 Python 对象,实现了
__len__(返回数据集大小)和__getitem__(获取特定索引的数据)。它可以是张量的简单包装,也可以按需从磁盘加载数据。 - DataLoader:包装数据集并生成批次(batches)。遍历整个数据集一次称为一个 epoch。设置
shuffle=True意味着每个 epoch 中样本的呈现顺序都会不同。
DataLoader 关键参数解析
loader = DataLoader(
dataset,
batch_size=64,
shuffle=True,
num_workers=2,
pin_memory=True,
persistent_workers=True,
)
num_workers:生成独立进程以并行预取批次,与 GPU 计算同时进行。如果设为 0,主进程将承担所有加载工作,这通常在数据密集型任务中成为 GPU 利用率的瓶颈。一般建议设置为 2-4,具体取决于 CPU 核心数和 I/O 速度。pin_memory=True:将批次张量分配在主机上的固定内存(pinned memory)中。这使得 GPU 的 DMA 引擎可以直接从固定内存传输数据,而无需先通过内核缓冲区复制,从而减少主机到设备的传输时间。仅在num_workers > 0且传输到 CUDA 时有效。persistent_workers=True:在 epoch 之间保持工作进程存活。如果不设置,每个 epoch 开始时都会重新生成工作进程,引入 fork 开销,这在大量工作进程时尤为明显。drop_last=True:如果最后一个批次的大小小于batch_size,则丢弃该批次。由于 BatchNorm 统计量基于少量样本(如 2-3 个)计算时会非常嘈杂,丢弃剩余数据有助于提高稳定性,尽管这会损失少量数据。- Batch Size 优化:较小的批次产生更嘈杂的梯度估计,起到隐式正则化的作用;较大的批次占用更多 GPU 内存但允许更多并行性。建议将批次大小和层维度设置为 8 或 16 的倍数,以对齐 Tensor Core 的 tile 大小(通常为 16x16 或 8x16,取决于数据类型)。
3. 设备移动与随机种子
.to(device):将张量移动到目标设备。对于张量,这不是原地操作:它返回一个新的张量,而原始张量保持不变。例如,X_batch.to('cuda')返回 GPU 上的新张量,而X_batch本身仍留在 CPU 上。- 设置随机种子:在构建模型和加载器之前设置种子,以确保每次运行结果相同,这对于实验复现至关重要。主要影响数据加载器的洗牌顺序和模型权重的初始化。
import random
import numpy as np
def set_seed(seed: int = 42):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(42)
torch.manual_seed:设置 CPU 生成器种子。torch.cuda.manual_seed_all:设置所有 GPU 的种子。- NumPy 和 Python
random:它们是独立的随机数生成器(RNG),PyTorch 不管理它们,因此需要单独设置种子。 cudnn.deterministic = True:强制 cuDNN 使用确定性卷积算法。某些 cuDNN 内核默认非确定性以提高吞吐量。确定性替代方案稍慢,但在开发阶段通常影响不大。cudnn.benchmark = False:必须与deterministic = True配对。当benchmark = True时,cuDNN 会针对每种输入形状配置多种算法并选择最快的一个,这个过程本身在不同运行之间会有所变化。将其固定为False可确保结果一致。- 多工作进程下的随机性:当
num_workers > 0时,每个 DataLoader 工作进程拥有自己的 RNG 状态,由操作系统在 fork 进程时播种。为了复现工作进程的随机性,应传递一个generator和一个worker_init_fn:
g = torch.Generator()
g.manual_seed(42)
loader = DataLoader(
dataset,
batch_size=64,
shuffle=True,
num_workers=2,
generator=g,
worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
)
4. 模型定义(nn.Module)
nn.Module 提供参数跟踪、设备移动、train/eval 模式切换和序列化功能。每个实例需要 __init__(注册所有子模块)和 forward(定义计算)。
class MLP(nn.Module):
def __init__(self, in_features, hidden, out_features):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_features, hidden),
nn.ReLU(),
nn.Linear(hidden, hidden),
nn.ReLU(),
nn.Linear(hidden, out_features),
)
def forward(self, x):
return self.net(x)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLP(in_features=2, hidden=128, out_features=3).to(device)
super().__init__():这是必需的,它初始化模块注册表。在__init__中将子模块或参数张量分配为属性会自动注册它们。未注册的普通属性将被排除在参数跟踪之外,导致训练
查看原文 →idlemachines.co.uk
