美文网首页pytorch
pytorch学习笔记-weight decay 和 learn

pytorch学习笔记-weight decay 和 learn

作者: 升不上三段的大鱼 | 来源:发表于2021-08-17 15:43 被阅读0次

1. Weight decay

Weight decay 是一种正则化方法,大概意思就是在做梯度下降之前,当前模型的 weight 做一定程度的 decay。
weights_{t+1} = (1-weight\_decay)*weight_t - lr * gradient
上面这个就相当于是 weights 减去下面公式对权重的梯度:
\frac{weight\_decay}{2*lr}weight^2 + loss
整理一下就是L2正则化:
loss = loss +\frac{ weight\_decay'}{2} * L_2 (weights)

所以当 weight\_decay' =\frac{weight\_decay}{lr} 的时候,L2正则化和 weight decay 是一样的,因此也会有人说L2正则就是权重衰减。在SGD中的确是这样,但是在 Adam中就不一定了。

使用 weight decay 可以:

  • 防止过拟合
  • 保持权重在一个较小在的值,避免梯度爆炸。因为在原本的 loss 函数上加上了权重值的 L2 范数,在每次迭代时,模不仅会去优化/最小化 loss,还会使模型权重最小化。让权重值保持尽可能小,有利于控制权重值的变化幅度(如果梯度很大,说明模型本身在变化很大,去过拟合样本),从而避免梯度爆炸。

在 pytorch 里可以设置 weight decay。torch.optim.Optimizer 里, SGD、ASGD 、Adam、RMSprop 等都有weight_decay参数设置:

optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)

参考:
Deep learning basic-weight decay
关于量化训练的一个小tip: weight-decay

2. Learning rate decay

知道梯度下降的,应该都知道学习率的影响,过大过小都会影响到学习的效果。Learning rate decay 的目的是在训练过程中逐渐降低学习率,pytorch 在torch.optim.lr_scheduler 里提供了很多花样。

Scheduler 的定义在 optimizer之后, 而参数更新应该在一个 epoch 结束之后。

optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', verbose=True)

for epoch in range(10):
   for input,label in dataloader:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

权重衰减(weight decay)与学习率衰减(learning rate decay)

相关文章

网友评论

    本文标题:pytorch学习笔记-weight decay 和 learn

    本文链接:https://www.haomeiwen.com/subject/gebbbltx.html