您的当前位置:首页优化器和将优化器加入CIFAR10神经网络
优化器和将优化器加入CIFAR10神经网络
来源:锐游网
1.什么是优化器
优化器的作用是在深度学习模型的训练过程中调整模型的参数(权重和偏置),以最小化损失函数(cost function),从而使模型更好地拟合训练数据。
在训练的每一步,优化器会执行以下操作:
计算梯度:根据当前的模型参数和训练数据,计算损失函数对每个参数的梯度。
更新参数:根据梯度和学习率,使用优化算法的规则更新模型的参数。
重复:多次重复这个过程,直到模型的损失达到最低或训练达到了预定的迭代次数。
2.常见优化器
1.SGD(随机梯度下降)
非常传统的优化器,最基本的优化算法,更新过程简单但容易陷入局部最优。
优点:简单易用,计算开销小。
缺点:由于只用一个样本(或小批量)计算梯度,容易陷入局部极小值,且在训练后期震荡较大。
适用场景:适合于简单的网络结构或作为复杂优化算法的基础。
# model.parameters():将模型的所有参数传递给优化器。
# lr:learning rate,学习速率,建议小数点后面两个0
torch.optim.SGD(model.parameters() , lr)
2.Adam
神,Adam 结合了 RMSProp 和动量法的思想,通过自适应调整学习率,同时维护了动量的一阶和二阶矩估计。
- 优点:适应性强,对不同类型的数据表现良好;对超参数不太敏感,默认参数通常效果不错。
- 缺点:有可能会陷入局部最优且收敛慢,特别是在大规模数据集上。
- 适用场景:广泛应用于各种深度学习任务,是目前最常用的优化器之一。
optimizer = torch.optim.Adam(z.parameters() , lr = 0.001)
3.AdamW
神中神,AdamW 是 Adam 的一种改进版本,在每次参数更新时加入了权重衰减(weight decay)来防止模型过拟合。
- 优点:与 L2 正则化相比效果更好,能更好地防止过拟合。
- 缺点:相较 Adam 略复杂,但总体使用体验较好。
- 适用场景:常用于 Transformer 等需要精细调参的深度学习模型。
3.将优化器加入CIFAR10神经网络
1.使用SGD
import torch
from torch import nn
from torch.nn import L1Loss,MSELoss,Sequential,Conv2d,MaxPool2d,Linear,Flatten
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
class Zilliax(nn.Module):
def __init__(self):
super(Zilliax , self).__init__()
self.model = Sequential(
Conv2d(3,32, 5, padding = 2),
MaxPool2d(2),
Conv2d(32 , 32 , 5 , padding = 2),
MaxPool2d(2),
Conv2d(32 , 64 , 5 , padding = 2),
MaxPool2d(2),
Flatten(),
Linear(1024 , 64),
Linear(64 , 10),
)
def forward(self , x):
x = self.model(x)
return x
dataset = torchvision.datasets.CIFAR10('E:\\PyCharm_Project\\Pytorch_2.3.1\\PytorchVision\\dataset',train = False , transform = torchvision.transforms.ToTensor(),download = True )
dataloader = DataLoader(dataset , batch_size = 64)
loss = nn.CrossEntropyLoss()
z = Zilliax()
# 随机梯度下降
"""
torch.optim.SGD(parameters() , lr , )
lr是learning rate,设置太大,模型训练并不稳定,设置太小,模型又比较慢,一般开始大参数,后面小参数,在SGD里建议是小数点后面1个0就可以
"""
optimizer = torch.optim.SGD(z.parameters() ,lr = 0.01)
# Adam
"""
Adam 结合了 RMSProp 和动量法的思想,通过自适应调整学习率,同时维护了动量的一阶和二阶矩估计。
优点:适应性强,对不同类型的数据表现良好;对超参数不太敏感,默认参数通常效果不错。
缺点:有可能会陷入局部最优且收敛慢,特别是在大规模数据集上。
适用场景:广泛应用于各种深度学习任务,是目前最常用的优化器之一。
lr建议是小数点后2个0,lr = 0.01时,loss到epoch16就开始不稳定上升
"""
# optimizer = optim.Adam(z.parameters() , lr = 0.001)
for epoch in range(20):
running_loss = 0.0
for data in dataloader:
imgs,targets = data
outputs = z(imgs)
result = loss(outputs , targets) # 交叉熵自带softmax
optimizer.zero_grad() # 重新设置梯度,清空上一次存入的梯度
result.backward()
optimizer.step()
running_loss = running_loss + result
print(running_loss)
2.Adam
import torch
from torch import nn
from torch.nn import L1Loss,MSELoss,Sequential,Conv2d,MaxPool2d,Linear,Flatten
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
class Zilliax(nn.Module):
def __init__(self):
super(Zilliax , self).__init__()
self.model = Sequential(
Conv2d(3,32, 5, padding = 2),
MaxPool2d(2),
Conv2d(32 , 32 , 5 , padding = 2),
MaxPool2d(2),
Conv2d(32 , 64 , 5 , padding = 2),
MaxPool2d(2),
Flatten(),
Linear(1024 , 64),
Linear(64 , 10),
)
def forward(self , x):
x = self.model(x)
return x
dataset = torchvision.datasets.CIFAR10('E:\\PyCharm_Project\\Pytorch_2.3.1\\PytorchVision\\dataset',train = False , transform = torchvision.transforms.ToTensor(),download = True )
dataloader = DataLoader(dataset , batch_size = 64)
loss = nn.CrossEntropyLoss()
z = Zilliax()
# 随机梯度下降
"""
torch.optim.SGD(parameters() , lr , )
lr是learning rate,设置太大,模型训练并不稳定,设置太小,模型又比较慢,一般开始大参数,后面小参数,在SGD里建议是小数点后面1个0就可以
"""
# optimizer = torch.optim.SGD(z.parameters() ,lr = 0.01)
# Adam
"""
Adam 结合了 RMSProp 和动量法的思想,通过自适应调整学习率,同时维护了动量的一阶和二阶矩估计。
优点:适应性强,对不同类型的数据表现良好;对超参数不太敏感,默认参数通常效果不错。
缺点:有可能会陷入局部最优且收敛慢,特别是在大规模数据集上。
适用场景:广泛应用于各种深度学习任务,是目前最常用的优化器之一。
lr建议是小数点后2个0,lr = 0.01时,loss到epoch16就开始不稳定上升
"""
optimizer = optim.Adam(z.parameters() , lr = 0.001)
for epoch in range(20):
running_loss = 0.0
for data in dataloader:
imgs,targets = data
outputs = z(imgs)
result = loss(outputs , targets) # 交叉熵自带softmax
optimizer.zero_grad() # 重新设置梯度,清空上一次存入的梯度
result.backward()
optimizer.step()
running_loss = running_loss + result
print(running_loss)
因篇幅问题不能全部显示,请点此查看更多更全内容