PyTorch SWA
Stochastic Weight Averaging (SWA) 是在比赛中使用率非常高的训练技巧,它可以帮助提高模型的泛化能力。
SWA 的原理非常简单,就是在训练过程中将不同时间切片的模型权重以某种方式来进行平均,从而得到一个新的模型,这个模型可以帮助提高模型的泛化能力。
PyTorch 从 1.6 开始就已经内置了 SWA 以及其相关工具的实现,本文主要针对实践作一些介绍。
推荐的实践方式
一般来说,不会一上来就使用 SWA,而是在训练进行到一定程度后再开始开启 SWA 的训练,或是对某个已经训练好的模型进行 fine-tune 的时候使用。 官方文档中推荐在训练的前 3 / 4 的阶段对模型进行直接训练,然后最后 1 / 4 阶段开启 SWA 训练,当然这个应该根据实际情况做调整。
作者的方法是,先进行一次普通的训练,观察验证集得分曲线,找出过拟合的 epoch,然后以这个 epoch 作为开启 SWA 的起点,重新进行一次训练。
学习率
在 SWA 训练的过程中,学习率也要进行变化,一般会使用一个较大且固定的学习率来训练模型。
作者的理解:较大的学习率在不同时间切片上更容易获得差异性化更大的权重,从而可以跳过局部最优点,提高模型的泛化能力。如果全都是近似的权重,均值的意义就不大了。
通常在 SWA 刚开启时会使用另外一个余弦退火的 Scheduler 来将学习率逐渐调整为一个较大的常数值。假如在常规训练中使用了 OneCycle 的 Scheduler,那么整个训练过程的学习率变化可能是这样的:
Batch Normalization
因为 BN 层在训练过程中会记录其输入的统计量,而当权重进行求均值后,这些统计量就会失效(因为不能简单得直接将这些统计量求均值)。因此如果需要使用 SWA 模型做预测,应该让 SWA 模型过一次训练数据以更新 BN 层的统计量。
而官方文档给的例子中,在整个训练结束后才更新了 BN 层。实际上,我们在训练过程中都会有 Validation 或是 Save Checkpoint 的操作,在做这些操作之前都应该提前更新 BN 层的统计量。
实现
SWA 实现也非常简单,PyTorch 提供了简单的 API 来帮我们实现 SWA 训练:
-
torch.optim.swa_utils.AveragedModel
用于创建 SWA 模型,类方法update_parameters
用于更新 SWA 模型的权重(均值操作),update_bn
用于更新 BN 层的统计量。 -
torch.optim.swa_utils.SWALR
是一个 Scheduler,用于在 SWA 训练的过程中调整学习率。
下面是作者常用的 SWA 的代码片段,涵盖了整片文章的全部内容:
from torch.optim.swa_utils import AveragedModel, SWALR
SWA_START_EPOCH = 80
model = ...
train_loader = ...
val_loader = ...
optimizer = SGD(model.parameters(), lr=0.03)
scheduler = OneCycle(optimizer)
# 创建 SWA 模型和用于 SWA 的 Scheduler,注意 Scheduler 需要仔细设置 swa_lr 和 anneal_epochs 这两个参数
# 与官方的例子不同的是,这里把 anneal_epochs 设置为了一个 epoch 的迭代数,然后 `step()` 的调用是在每个迭代
# 中进行,而不是官方的每个 epoch 完了之后进行
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=1e-5, anneal_epochs=len(train_loader))
for epoch in range(100):
swa_enabled = epoch > SWA_START_EPOCH
_scheduler = swa_scheduler if swa_enabled else scheduler
# 训练原始模型
model.train()
for inputs, targets in train_loader:
loss = compute_loss(model(batch), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
_scheduler.step()
# 更新 SWA 模型权重,并同时也更新 BN
if swa_enabled:
swa_model.update_parameters(model)
swa_model.update_bn(train_loader)
# 验证,如果开启了 SWA 则使用 SWA 模型来做推断
val_model = swa_model if swa_enabled else model
val_model.eval()
loss = 0
for inputs, targets in val_loader:
loss += compute_loss(model(batch), targets)
print(f"val loss: {loss / len(val_loader):.2f})