就是LeNet嘛qwq
pytorch官网教程那个!!
内容包含:
- LeNet模型及实现
- 用数据集进行验证
- 用代码下载数据集太慢的解决办法
模型部分
原理差不多就是这样
下面是模型部分的代码实现1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5) # input_channel,output_channel,kernel_size
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x)) # input(3,32,32) output(16,28,28)
x = self.pool1(x) # output(16,14,14)
x = F.relu(self.conv2(x)) # output(32,10,10)
x = self.pool2(x) # output(32,5,5)
x = x.view(-1, 32 * 5 * 5)
x = F.relu(self.fc1(x)) # output(120)
x = F.relu(self.fc2(x)) # output(84)
x = F.relu(self.fc3(x)) # output(10)
return x
# 测试
# import torch
# input1 = torch.rand([32, 3, 32, 32])
# model = LeNet()
# print(model)
# output = model(input1)
训练数据
下载数据集
1 | # 预处理 |
其中transforms.Compose()将我们所使用的预处理方法打包成整体,transforms.ToTensor()将图片或numpy转换为tensor,transforms.Normalize()第一个参数为均值,第二个参数为标准差,输出结果为(原始值-均值)/标准差
但是,官网下载速度实在感人,所以不得不换其他方法……
具体办法:https://blog.csdn.net/qq_43280818/article/details/104241326
将所下载的数据集.tar.gz文件替换到./data文件夹即可
注意将数据集替换后,不要将download的True改成False,否则会出锅
显示图片
1 | import torch |
代码运行结果:
进行训练
1 | import torch |
记得每计算一个batch,都要调用optimizer.zero_grad(),否则会对计算的历史梯度进行累加(可以通过这个特性变相实现一个很大的batch)
用cpu硬刚真的难顶qwqwq
然而并不知道为什么准确率这么低QAQ,人家明明有0.686
划完水继续做qwq
预测
1 | import torch |
尝试运行一下qwq!
在./LeNet目录下存一张可爱的橘猫,叫1.jpg
运行一下我萌的代码!!!
哦凑为什么猫猫变成了卡车???面壁ing……