[bilibili乱搞]LeNet原理及实现

就是LeNet嘛qwq
pytorch官网教程那个!!
内容包含:

  • LeNet模型及实现
  • 用数据集进行验证
  • 用代码下载数据集太慢的解决办法

模型部分

图走丢了qwq
原理差不多就是这样
下面是模型部分的代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN for 3x32x32 inputs (e.g. CIFAR-10), 10 output classes.

    forward() returns raw, unnormalized class logits of shape (N, 10),
    suitable for nn.CrossEntropyLoss.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        # Conv2d args: in_channels, out_channels, kernel_size
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # input (3, 32, 32) -> (16, 28, 28)
        x = self.pool1(x)           # -> (16, 14, 14)
        x = F.relu(self.conv2(x))   # -> (32, 10, 10)
        x = self.pool2(x)           # -> (32, 5, 5)
        x = x.view(-1, 32 * 5 * 5)  # flatten to (N, 800)
        x = F.relu(self.fc1(x))     # -> (120)
        x = F.relu(self.fc2(x))     # -> (84)
        # BUG FIX: no ReLU on the output layer. CrossEntropyLoss expects raw
        # logits; clamping them to >= 0 cripples training and explains the
        # unexpectedly low accuracy observed during training.
        x = self.fc3(x)             # -> (10) raw class logits
        return x


# 测试
# import torch
# input1 = torch.rand([32, 3, 32, 32])
# model = LeNet()
# print(model)
# output = model(input1)

训练数据

下载数据集

1
2
3
4
5
6
7
8
# Pre-processing pipeline: convert PIL image / ndarray to a tensor, then
# normalize each channel with mean 0.5 and std 0.5, i.e. map pixels to
# roughly [-1, 1] via (value - mean) / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split: 50,000 images.
train_set = torchvision.datasets.CIFAR10(
    root='./LeNet/data',
    train=True,
    download=True,
    transform=transform,
)

其中 transforms.Compose() 将我们所使用的预处理方法打包成一个整体;transforms.ToTensor() 将 PIL 图片或 numpy 数组转换为 tensor;transforms.Normalize() 第一个参数为均值,第二个参数为标准差,输出结果为 (原始值 - 均值) / 标准差。
但是,官网下载速度实在感人,所以不得不换其他方法……
具体办法:https://blog.csdn.net/qq_43280818/article/details/104241326
将所下载的数据集.tar.gz文件替换到./data文件夹即可
注意将数据集文件替换后,download 保持 True 即可:torchvision 会先校验本地文件,文件已存在且完整时不会重复下载;改成 False 时若路径或文件没对上就会直接报错。

显示图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Pre-processing: to tensor, then per-channel normalize to roughly [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10 training split: 50,000 images.
train_set = torchvision.datasets.CIFAR10(root='./LeNet/data', train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=True, num_workers=0)

# Every class under torchvision.datasets is a ready-made dataset.

# CIFAR-10 test split: 10,000 images.
val_set = torchvision.datasets.CIFAR10(root='./LeNet/data', train=False,
                                       download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
# BUG FIX: DataLoader iterators no longer expose a .next() method in
# current PyTorch; the builtin next() works on every version.
val_image, val_label = next(val_data_iter)


# CIFAR-10 class names; tuple indexed from 0, matching the label ids.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def imshow(img):
    """Un-normalize a CHW image tensor and display it with matplotlib."""
    img = img / 2 + 0.5  # invert Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    # Tensor layout is (channel, height, width); matplotlib expects
    # (height, width, channel), hence the transpose.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# print labels
print(' '.join('%5s' % classes[val_label[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(val_image))

代码运行结果:

图丢了qwq

进行训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms

# Pre-processing: to tensor, then per-channel normalize to roughly [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10 training split: 50,000 images.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=True, num_workers=0)

# Every class under torchvision.datasets is a ready-made dataset.

# CIFAR-10 test split: 10,000 images, loaded as one big batch so the
# whole validation set is evaluated in a single forward pass.
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=10000,
                                         shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
# BUG FIX: DataLoader iterators no longer expose a .next() method in
# current PyTorch; the builtin next() works on every version.
val_image, val_label = next(val_data_iter)


# CIFAR-10 class names; tuple indexed from 0, matching the label ids.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()                                       # instantiate the model
loss_function = nn.CrossEntropyLoss()               # expects raw logits
optimizer = optim.Adam(net.parameters(), lr=0.001)  # optimizer

for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0  # accumulated loss over the reporting window
    for step, data in enumerate(train_loader, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients — gradients accumulate by default,
        # so skipping this would sum gradients across batches
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()  # item() extracts the scalar value
        if step % 500 == 499:  # print every 500 mini-batches
            # no_grad: skip building the autograd graph during evaluation,
            # saving memory and time
            with torch.no_grad():
                outputs = net(val_image)  # [batch, 10]
                # max over dim=1; [1] keeps only the argmax indices
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == val_label).sum().item() / val_label.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished Training')

save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

记得每计算一个batch,都要调用optimizer.zero_grad(),否则会对计算的历史梯度进行累加(可以通过这个特性变相实现一个很大的batch)
用cpu硬刚真的难顶qwqwq
图丢了qwq
然而并不知道为什么准确率这么低QAQ,人家明明有0.686(后来发现原因:forward 的最后一层 fc3 之后多加了一个 ReLU。CrossEntropyLoss 需要原始 logits,把它们截断成非负会严重影响训练,去掉最后这个 ReLU 即可恢复正常)
图丢了qwq
划完水继续做qwq

预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

# Resize to the 32x32 input the network expects, then normalize like training.
transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # load the saved weights

im = Image.open('1.jpg')
# ROBUSTNESS FIX: force 3 channels. A grayscale or RGBA image would
# otherwise crash in Normalize, which expects exactly 3 channels.
im = im.convert('RGB')
im = transform(im)                 # [C, H, W]
im = torch.unsqueeze(im, dim=0)    # [N, C, H, W] — prepend batch dimension

with torch.no_grad():
    outputs = net(im)
    # argmax over the class dimension; .data is legacy and unnecessary
    # inside no_grad
    predict = torch.max(outputs, dim=1)[1].numpy()
print(classes[int(predict)])

尝试运行一下qwq!
在./LeNet目录下存一张可爱的橘猫,叫1.jpg
图丢了qwq
运行一下我萌的代码!!!
upload successful
哦凑为什么猫猫变成了卡车???面壁ing……