3.7 训练和评估⚓︎
我们在完成了模型的训练后,需要在测试集/验证集上完成模型的验证,以确保我们的模型具有泛化能力、不会出现过拟合等问题。在PyTorch中,训练和评估的流程是一致的,只是在训练过程中需要将模型的参数进行更新,而在评估过程中则不需要更新参数。
经过本节的学习,你将收获:
- PyTorch的训练/评估模式的开启
- 完整的训练/评估流程
完成了上述设定后就可以加载数据开始训练模型了。首先应该设置模型的状态:如果是训练状态,那么模型的参数应该支持反向传播的修改;如果是验证/测试状态,则不应该修改模型参数。在PyTorch中,模型的状态设置非常简便,如下的两个操作二选一即可:
model.train() # 训练状态
model.eval() # 验证/测试状态
我们前面在DataLoader构建完成后介绍了如何从中读取数据,在训练过程中使用类似的操作即可,区别在于此时要用for循环读取DataLoader中的全部数据。
for data, label in train_loader:
之后将数据放到GPU上用于后续计算,此处以.cuda()为例
data, label = data.cuda(), label.cuda()
开始用当前批次数据做训练时,应当先将优化器的梯度置零:
optimizer.zero_grad()
之后将data送入模型中训练:
output = model(data)
根据预先定义的criterion计算损失函数:
loss = criterion(output, label)
将loss反向传播回网络:
loss.backward()
使用优化器更新模型参数:
optimizer.step()
这样一个训练过程就完成了,后续还可以计算模型准确率等指标,这部分会在下一节的图像分类实战中加以介绍。
验证/测试的流程基本与训练过程一致,不同点在于:
- 需要预先设置torch.no_grad,以及将model调至eval模式
- 不需要将优化器的梯度置零
- 不需要将loss反向回传到网络
- 不需要更新optimizer
一个完整的图像分类的训练过程如下所示:
def train(epoch):
model.train()
train_loss = 0
for data, label in train_loader:
data, label = data.cuda(), label.cuda()
optimizer.zero_grad()
output = model(data)
loss = criterion(output, label)
loss.backward()
optimizer.step()
train_loss += loss.item()*data.size(0)
train_loss = train_loss/len(train_loader.dataset)
print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
对应的,一个完整图像分类的验证过程如下所示:
def val(epoch):
model.eval()
val_loss = 0
with torch.no_grad():
for data, label in val_loader:
data, label = data.cuda(), label.cuda()
output = model(data)
preds = torch.argmax(output, 1)
loss = criterion(output, label)
val_loss += loss.item()*data.size(0)
running_accu += torch.sum(preds == label.data)
val_loss = val_loss/len(val_loader.dataset)
print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, val_loss))
from sklearn.metrics import classification_report
"""
将下方代码的labels和preds替换为模型预测出来的所有label和preds,
target_names替换为类别名称,
既可得到模型的分类报告
"""
print(classification_report(labels.cpu(), preds.cpu(), target_names=class_names))
torcheval
或torchmetric
来对模型进行评估。
最后更新:
November 30, 2023
创建日期: November 30, 2023
创建日期: November 30, 2023